lowering.py 232 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
7727872797280728172827283728472857286728772887289729072917292729372947295729672977298
  1. # mypy: allow-untyped-defs
  2. from __future__ import annotations
  3. import contextlib
  4. import dataclasses
  5. import functools
  6. import itertools
  7. import logging
  8. import math
  9. import operator
  10. import os
  11. import textwrap
  12. import warnings
  13. from collections import defaultdict
  14. from collections.abc import Iterable, Sequence
  15. from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union
  16. from typing_extensions import ParamSpec
  17. from unittest.mock import patch
  18. import sympy
  19. import torch
  20. import torch.ao.quantization.fx._decomposed
  21. import torch.fx
  22. import torch.utils._pytree as pytree
  23. from torch._dynamo.utils import counters
  24. from torch._higher_order_ops.associative_scan import associative_scan_op
  25. from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
  26. from torch._library.utils import get_layout_constraint_tag
  27. from torch._prims_common import (
  28. canonicalize_dim,
  29. canonicalize_dims,
  30. check,
  31. dtype_to_type,
  32. elementwise_dtypes,
  33. ELEMENTWISE_TYPE_PROMOTION_KIND,
  34. get_computation_dtype,
  35. is_boolean_dtype,
  36. is_float_dtype,
  37. is_integer_dtype,
  38. Number,
  39. )
  40. from torch.fx.experimental.sym_node import magic_methods, method_to_operator
  41. from torch.fx.experimental.symbolic_shapes import (
  42. free_unbacked_symbols,
  43. has_free_unbacked_symbols,
  44. resolve_unbacked_bindings,
  45. )
  46. from torch.utils._ordered_set import OrderedSet
  47. from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing
  48. from .._dynamo.utils import import_submodule
  49. from . import config, inductor_prims, ir, test_operators # NOQA: F401
  50. from .decomposition import decompositions, get_decompositions
  51. from .ir import (
  52. BaseView,
  53. DtypeView,
  54. ExpandView,
  55. IndexingConstant,
  56. IRNode,
  57. is_triton,
  58. MutableBox,
  59. OnlineSoftmaxReduction,
  60. ops_wrapper,
  61. PermuteView,
  62. Pointwise,
  63. Reduction,
  64. ShapeAsConstantBuffer,
  65. SqueezeView,
  66. TensorBox,
  67. validate_ir,
  68. View,
  69. )
  70. from .utils import (
  71. ceildiv,
  72. decode_device,
  73. is_dynamic,
  74. is_gpu,
  75. is_pointwise_use,
  76. is_view,
  77. needs_fallback_due_to_atomic_add_limitations,
  78. pad_listlike,
  79. register_op_dtype_propagation_rules,
  80. register_op_requires_libdevice_fp64,
  81. sympy_product,
  82. use_scatter_fallback,
  83. )
  84. from .virtualized import ops, V
  85. if TYPE_CHECKING:
  86. from .ops_handler import ReductionType
  87. _T = TypeVar("_T")
  88. _P = ParamSpec("_P")
  89. # TODO(jansel): we should implement decomps or lowerings for these
  90. # https://github.com/pytorch/torchdynamo/issues/327
  91. FALLBACK_ALLOW_LIST = OrderedSet(
  92. [
  93. "torchvision::roi_align",
  94. "aten::index_add",
  95. ]
  96. )
  97. log = logging.getLogger(__name__)
  98. lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}
  99. # Use maybe_layout_constraints to access this dict, we lazily register tag-based layout constraints
  100. _maybe_layout_constraints: dict[
  101. torch._ops.OpOverload, Optional[Callable[..., Any]]
  102. ] = {}
  103. fallbacks = OrderedSet[torch._ops.OpOverload]()
  104. aten = torch.ops.aten
  105. tr_c10d = torch.ops.tr_c10d
  106. prims = torch.ops.prims
  107. needs_realized_inputs = OrderedSet[torch._ops.OpOverload]()
  108. foreach_ops = OrderedSet[torch._ops.OpOverload](
  109. [torch._higher_order_ops._foreach_map] # type: ignore[list-item]
  110. )
  111. # TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload
  112. # so why is it in foreach_ops?
  113. inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]()
  114. inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
  115. quantized_decomposed = torch.ops.quantized_decomposed
  116. def cur_node_has_non_foreach_users():
  117. for node in V.graph.current_node.users:
  118. for user in node.users:
  119. if not (user.op == "call_function" and (user.target in foreach_ops)):
  120. return True
  121. return False
  122. # group by device, whether any of the inputs are dynamic
  123. # note arg_pairs may or may not be a pair
  124. # foreach_map for example just passes output buffers here
  125. def group_foreach_args(arg_pairs: Iterable[Union[tuple[Any, Any], Any]]):
  126. out = defaultdict(list)
  127. unpack_args = False
  128. for i, args in enumerate(arg_pairs):
  129. if not isinstance(args, Iterable):
  130. unpack_args = True
  131. args = (args,)
  132. use_foreach = (
  133. not is_dynamic(*args) or config.combo_kernel_foreach_dynamic_shapes
  134. )
  135. device = None
  136. for t in args:
  137. if isinstance(t, TensorBox):
  138. device = t.data.get_device()
  139. break
  140. assert device is not None, "foreach op should have at least one tensor arg"
  141. if unpack_args:
  142. (args,) = args
  143. out[(device, use_foreach)].append((i, args))
  144. return out
  145. def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
  146. """Get layout constraints. Returns None if there are no layout constraints."""
  147. if not isinstance(fn, torch._ops.OpOverload):
  148. # Only OpOverloads have layout constraints.
  149. return None
  150. if maybe_layout_tag := get_layout_constraint_tag(fn, with_default=False):
  151. return tag_to_layout_constraint(maybe_layout_tag)
  152. if fn in _maybe_layout_constraints:
  153. return _maybe_layout_constraints[fn]
  154. return None
  155. def tag_to_layout_constraint(tag):
  156. if tag == torch._C.Tag.needs_exact_strides:
  157. return constrain_to_fake_tensors
  158. if tag == torch._C.Tag.needs_contiguous_strides: # type: ignore[attr-defined]
  159. return require_contiguous_strides
  160. if tag == torch._C.Tag.needs_fixed_stride_order:
  161. return constrain_to_fx_strides
  162. if tag == torch._C.Tag.flexible_layout:
  163. return None
  164. raise AssertionError(f"Unknown layout constraint tag: {tag}")
  165. def assert_nyi(cond, msg):
  166. if not cond:
  167. raise NotImplementedError(f"inductor does not support {msg}")
  168. def add_needs_realized_inputs(fn):
  169. if isinstance(fn, (list, set, tuple, OrderedSet)): # noqa: set_linter
  170. return [add_needs_realized_inputs(x) for x in fn]
  171. needs_realized_inputs.add(fn)
  172. if isinstance(fn, torch._ops.OpOverloadPacket):
  173. needs_realized_inputs.update(
  174. getattr(fn, overload) for overload in fn.overloads()
  175. )
  176. def add_layout_constraint(fn, constraint):
  177. if isinstance(fn, torch._ops.OpOverloadPacket):
  178. for overload in fn.overloads():
  179. _maybe_layout_constraints[getattr(fn, overload)] = constraint
  180. else:
  181. _maybe_layout_constraints[fn] = constraint
  182. add_needs_realized_inputs(
  183. [
  184. aten.as_strided,
  185. aten.as_strided_copy,
  186. aten.avg_pool2d,
  187. aten.avg_pool2d_backward,
  188. aten.bmm,
  189. aten.convolution,
  190. aten.convolution_backward,
  191. aten.max_pool2d_with_indices,
  192. aten.max_pool3d_with_indices,
  193. aten.max_pool2d_with_indices_backward,
  194. aten.mm,
  195. aten.upsample_nearest2d,
  196. aten._upsample_nearest_exact2d,
  197. aten._int_mm,
  198. ]
  199. )
  200. # TODO(jansel): ezyang says we won't need this in the future, try removing it
  201. # based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
  202. DTYPE_ID_LOOKUP = {
  203. 0: torch.uint8,
  204. 1: torch.int8,
  205. 2: torch.int16,
  206. 3: torch.int32,
  207. 4: torch.int64,
  208. 5: torch.float16,
  209. 6: torch.float32,
  210. 7: torch.float64,
  211. 8: torch.complex32,
  212. 9: torch.complex64,
  213. 10: torch.complex32,
  214. 11: torch.bool,
  215. 15: torch.bfloat16,
  216. # TODO(jansel): add quantized types?
  217. # _(c10::qint8, QInt8) /* 12 */
  218. # _(c10::quint8, QUInt8) /* 13 */
  219. # _(c10::qint32, QInt32) /* 14 */
  220. # _(c10::quint4x2, QUInt4x2) /* 16 */
  221. # _(c10::quint2x4, QUInt2x4) /* 17 */
  222. }
  223. def decode_dtype(dtype: int):
  224. if not isinstance(dtype, int):
  225. return dtype
  226. assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
  227. dtype = DTYPE_ID_LOOKUP[dtype]
  228. return dtype
  229. def is_integer_type(x):
  230. if isinstance(x, TensorBox):
  231. return is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  232. elif isinstance(x, sympy.Expr):
  233. return x.is_integer is True # type: ignore[attr-defined]
  234. else:
  235. return isinstance(x, int)
  236. def is_boolean_type(x):
  237. if isinstance(x, TensorBox):
  238. return is_boolean_dtype(x.get_dtype())
  239. else:
  240. return isinstance(x, bool)
  241. def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
  242. def construct_input(inp):
  243. if isinstance(inp, (Number, sympy.Basic)):
  244. return inp
  245. else:
  246. dim = len(inp.get_size())
  247. # construct a tmp tensor to feed into torch.result_type
  248. return torch.zeros([1] * dim, dtype=inp.get_dtype())
  249. inps = [construct_input(arg) for arg in args]
  250. _, dtype = elementwise_dtypes(*inps, type_promotion_kind=type_promotion_kind)
  251. return dtype
  252. def get_overloads(aten_fn):
  253. if not isinstance(aten_fn, (list, tuple)):
  254. aten_fn = [aten_fn]
  255. else:
  256. aten_fn = list(aten_fn)
  257. for fn in list(aten_fn):
  258. if isinstance(fn, torch._ops.OpOverloadPacket):
  259. for overload in fn.overloads():
  260. other_fn = getattr(fn, overload)
  261. if other_fn not in lowerings:
  262. aten_fn.append(other_fn)
  263. return aten_fn
  264. def in_namespace(op, namespace):
  265. if isinstance(op, torch._ops.OpOverloadPacket):
  266. return namespace in op._qualified_op_name
  267. elif isinstance(op, torch._ops.OpOverload):
  268. return namespace in op.name()
  269. return False
  270. def maybe_copy_cpu_scalar(x: TensorBox, device: torch.device) -> TensorBox:
  271. """
  272. Copy cpu scalar if doesn't not match with given `device`
  273. """
  274. if not isinstance(x.data, ir.ReinterpretView) or has_free_unbacked_symbols(
  275. x.get_size()
  276. ):
  277. return x
  278. size = [V.graph.sizevars.size_hint_or_throw(s) for s in x.get_size()]
  279. cur_device = x.get_device()
  280. if (
  281. cur_device is not None
  282. and cur_device.type == "cpu"
  283. and cur_device != device
  284. and (len(size) == 0 or (len(size) == 1 and size[0] == 1))
  285. ):
  286. return TensorBox(ir.StorageBox(ir.DeviceCopy.create(x, cur_device, False)))
  287. return x
  288. def transform_args(
  289. args: list[Any],
  290. kwargs: dict[str, Any],
  291. broadcast: bool,
  292. type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
  293. convert_input_to_bool: bool,
  294. ) -> tuple[list[Any], dict[str, Any]]:
  295. """
  296. Transforms arguments for broadcasting and type promotion
  297. """
  298. args_indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
  299. kwargs_indices = [k for k, v in kwargs.items() if isinstance(v, TensorBox)]
  300. # check that there's something to transform
  301. if not args_indices and not kwargs_indices:
  302. return args, kwargs
  303. if type_promotion_kind or convert_input_to_bool:
  304. if convert_input_to_bool:
  305. dtype = torch.bool
  306. else:
  307. # FIXME this is a crude approximation for promoting args
  308. promoting_args = [
  309. a
  310. for a in args
  311. if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype")
  312. ]
  313. # only consider tensor kwargs for promotion, for now
  314. promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
  315. dtype = get_promoted_dtype(
  316. *promoting_args,
  317. type_promotion_kind=type_promotion_kind, # type: ignore[arg-type]
  318. )
  319. device = (
  320. args[args_indices[0]] if args_indices else kwargs[kwargs_indices[0]]
  321. ).get_device()
  322. for i in args_indices:
  323. args[i] = maybe_copy_cpu_scalar(args[i], device)
  324. for k in kwargs_indices:
  325. kwargs[k] = maybe_copy_cpu_scalar(kwargs[k], device)
  326. # sometimes args are an immutable list so we can't mutate them
  327. def promote(arg):
  328. if isinstance(arg, TensorBox):
  329. return to_dtype(arg, dtype)
  330. elif isinstance(arg, ir.Constant):
  331. return ir.Constant(value=arg.value, dtype=dtype, device=device)
  332. else:
  333. return arg
  334. args = [promote(a) for a in args]
  335. kwargs = {k: promote(v) for k, v in kwargs.items()}
  336. if broadcast:
  337. broadcasted = broadcast_tensors(
  338. *list(
  339. itertools.chain(
  340. (args[i] for i in args_indices),
  341. (kwargs[k] for k in kwargs_indices),
  342. )
  343. )
  344. )
  345. size = list(broadcasted[0].get_size())
  346. for i, x in zip(args_indices, broadcasted[: len(args_indices)]):
  347. args[i] = x
  348. for k, x in zip(kwargs_indices, broadcasted[len(args_indices) :]):
  349. kwargs[k] = x
  350. for i in range(len(args)):
  351. if isinstance(args[i], ir.Constant):
  352. args[i] = ExpandView.create(args[i], size)
  353. for k in kwargs:
  354. if isinstance(kwargs[k], ir.Constant):
  355. kwargs[k] = ExpandView.create(kwargs[k], size)
  356. return args, kwargs
  357. def _register_foreach_lowering(aten_fn, decomp_fn):
  358. """
  359. Add a foreach lowering to lowerings dict.
  360. Arguments:
  361. aten_fn: torch.ops.aten.* fn we are lowering
  362. decomp_fn: alternate implementation on our IR
  363. broadcast: True to apply broadcasting to tensor inputs
  364. type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
  365. convert_input_to_bool: some logical ops require inputs are converted to bool
  366. """
  367. @functools.wraps(decomp_fn)
  368. def wrapped(*args, **kwargs):
  369. assert len(args) <= 2
  370. out = decomp_fn(*args, **kwargs)
  371. validate_ir(out)
  372. return out
  373. aten_fns = get_overloads(aten_fn)
  374. foreach_ops.update(aten_fns)
  375. lowerings.update(dict.fromkeys(aten_fns, wrapped))
  376. return wrapped
  377. def _register_lowering(
  378. aten_fn,
  379. decomp_fn,
  380. broadcast,
  381. type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
  382. convert_input_to_bool,
  383. lowering_dict,
  384. ):
  385. """
  386. Add a lowering to lowerings dict
  387. Arguments:
  388. aten_fn: torch.ops.aten.* fn we are lowering
  389. decomp_fn: alternate implementation on our IR
  390. broadcast: True to apply broadcasting to tensor inputs
  391. type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
  392. convert_input_to_bool: some logical ops require inputs are converted to bool
  393. """
  394. @functools.wraps(decomp_fn)
  395. def wrapped(*args, **kwargs):
  396. args: list[Any] = list(args)
  397. kwargs: dict[str, Any] = dict(kwargs)
  398. unpacked = False
  399. # TODO maybe we need to use pytrees here
  400. if len(args) == 1 and isinstance(args[0], (list, tuple)):
  401. unpacked = True
  402. args = list(args[0])
  403. if not all(
  404. (fn in fallbacks or in_namespace(fn, "_c10d_functional")) for fn in aten_fn
  405. ):
  406. # explicitly assert for "out=" ops for better error messages
  407. assert not any(x == "out" for x in kwargs.keys()), (
  408. "out= ops aren't yet supported"
  409. )
  410. args, kwargs = transform_args(
  411. args, kwargs, broadcast, type_promotion_kind, convert_input_to_bool
  412. )
  413. if unpacked:
  414. args = [args]
  415. out = decomp_fn(*args, **kwargs)
  416. validate_ir(out)
  417. return out
  418. aten_fn = get_overloads(aten_fn)
  419. lowering_dict.update(dict.fromkeys(aten_fn, wrapped))
  420. return wrapped
  421. def register_lowering(
  422. aten_fn,
  423. broadcast=False,
  424. type_promotion_kind: Optional[
  425. ELEMENTWISE_TYPE_PROMOTION_KIND
  426. ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  427. convert_input_to_bool=False,
  428. lowering_dict=lowerings,
  429. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  430. """
  431. Shim to support decorator syntax.
  432. """
  433. return functools.partial(
  434. _register_lowering,
  435. aten_fn,
  436. broadcast=broadcast,
  437. type_promotion_kind=type_promotion_kind,
  438. convert_input_to_bool=convert_input_to_bool,
  439. lowering_dict=lowering_dict,
  440. )
  441. def broadcast_symbolic_shapes(a, b):
  442. """
  443. Broadcasting logic based on symbolic shapes.
  444. We give the shapes 0 and 1 concrete values, while all other shapes
  445. are symbolic sympy formulas.
  446. """
  447. output = []
  448. for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One):
  449. if V.graph.sizevars.is_size_one_or_false(y):
  450. output.append(x)
  451. elif V.graph.sizevars.is_size_one_or_false(x):
  452. output.append(y)
  453. else:
  454. V.graph.sizevars.check_equals(x, y)
  455. if len(sympy.expand(y).free_symbols) < len(sympy.expand(x).free_symbols):
  456. output.append(y) # prefer shorter formula
  457. else:
  458. output.append(x)
  459. return tuple(reversed(output))
  460. def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
  461. assert override_return_dtype is None or type_promotion_kind is None, (
  462. "only one of override_return_dtype or type_promotion_kind may be given"
  463. )
  464. if override_return_dtype is None and type_promotion_kind is None:
  465. type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  466. if not any(isinstance(x, (sympy.Basic, int, float)) for x in inputs):
  467. return inputs
  468. if all(isinstance(x, (int, float, sympy.Basic)) for x in inputs):
  469. dtype = override_return_dtype or get_promoted_dtype(
  470. *inputs, type_promotion_kind=type_promotion_kind
  471. )
  472. def const_func(x):
  473. if isinstance(x, sympy.Basic):
  474. return ir.IndexingConstant(
  475. index=x, dtype=dtype, device=decode_device(None)
  476. )
  477. else:
  478. return ir.Constant(value=x, dtype=dtype, device=decode_device(None))
  479. return [const_func(x) for x in inputs]
  480. ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant)))
  481. out = []
  482. for x in inputs:
  483. if isinstance(x, (int, float)):
  484. out.append(
  485. ExpandView.create(
  486. ir.Constant(
  487. value=x, dtype=ex.get_dtype(), device=ex.get_device_or_error()
  488. ),
  489. list(ex.get_size()),
  490. )
  491. )
  492. elif isinstance(x, sympy.Basic):
  493. out.append(
  494. ExpandView.create(
  495. IndexingConstant(
  496. index=x, dtype=ex.get_dtype(), device=ex.get_device_or_error()
  497. ),
  498. list(ex.get_size()),
  499. )
  500. )
  501. else:
  502. out.append(x)
  503. return out
  504. def make_pointwise(
  505. fn,
  506. override_return_dtype=None,
  507. override_device=None,
  508. override_fn_when_input_bool=None,
  509. allow_alpha=False,
  510. triton_fallback=None,
  511. ):
  512. def inner(*inputs: TensorBox, alpha=None):
  513. if triton_fallback is not None and any(
  514. isinstance(inp, IRNode) and is_triton(inp) for inp in inputs
  515. ):
  516. assert not allow_alpha # not implemented
  517. return triton_fallback(*inputs)
  518. inputs = promote_constants(inputs, override_return_dtype)
  519. if allow_alpha:
  520. if alpha is not None and alpha != 1:
  521. inputs = list(inputs)
  522. inputs[-1] = mul(inputs[-1], alpha)
  523. else:
  524. assert alpha is None
  525. loaders = [x.make_loader() for x in inputs]
  526. ranges = inputs[0].get_size()
  527. dtype = override_return_dtype or inputs[0].get_dtype()
  528. for other in inputs[1:]:
  529. assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
  530. other.get_size()
  531. ), f"ndim mismatch {fn} {ranges} {other.get_size()}"
  532. # in tracing, we will annotate pointwise nodes that correspond to the output of
  533. # a pointwise node that would have been run in eager. intermediary pointwise nodes
  534. # during decompositions are not annotated.
  535. low_pr_fp = (torch.bfloat16, torch.float16)
  536. emulate_precision_casts = (
  537. V.graph is not None
  538. and getattr(V.graph, "current_node", None) is not None
  539. and V.graph.current_node.meta is not None
  540. and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False)
  541. and dtype in low_pr_fp
  542. )
  543. def inner_fn(index):
  544. assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
  545. if dtype == torch.bool and override_fn_when_input_bool is not None:
  546. return override_fn_when_input_bool(*[load(index) for load in loaders])
  547. else:
  548. inputs_loaded = []
  549. for inp_index, load in enumerate(loaders):
  550. out = load(index)
  551. inp_dtype = inputs[inp_index].get_dtype()
  552. if emulate_precision_casts and inp_dtype in low_pr_fp:
  553. downcast = ops.to_dtype(out, inp_dtype, use_compute_types=False)
  554. out = ops.to_dtype(downcast, inp_dtype)
  555. inputs_loaded.append(out)
  556. out = fn(*inputs_loaded)
  557. if emulate_precision_casts:
  558. # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here,
  559. # then upcasting again, to emulate casts that eager would do.
  560. downcast = ops.to_dtype(out, dtype, use_compute_types=False)
  561. return ops.to_dtype(downcast, dtype)
  562. return out
  563. if not override_device:
  564. device = None
  565. for i in inputs:
  566. if is_gpu(i.get_device().type):
  567. device = i.get_device()
  568. break
  569. if not device:
  570. device = inputs[0].get_device()
  571. device = override_device or device
  572. return Pointwise.create(
  573. device=device, # type: ignore[arg-type]
  574. dtype=dtype,
  575. inner_fn=inner_fn,
  576. ranges=ranges,
  577. )
  578. return inner
  579. def make_foreach_pointwise(pw_fn, allow_alpha=False):
  580. def inner(*inputs: list[list[TensorBox]], alpha=1):
  581. realize_outputs = (
  582. len(V.graph.current_node.users) == 0
  583. or V.graph.current_node.target in inplace_foreach_ops
  584. or cur_node_has_non_foreach_users()
  585. )
  586. a_list_input = None
  587. for input in inputs:
  588. if isinstance(input, (list, tuple)):
  589. a_list_input = input
  590. break
  591. assert a_list_input is not None, (
  592. "at least one input must be a list to a foreach op"
  593. )
  594. # broadcast scalar inputs to match length of list inputs
  595. broadcast_inputs = []
  596. for input in inputs:
  597. if not isinstance(input, (list, tuple)):
  598. broadcast_inputs.append([input] * len(a_list_input))
  599. else:
  600. broadcast_inputs.append(input)
  601. groups = group_foreach_args(zip(*broadcast_inputs))
  602. outputs = [None] * len(a_list_input)
  603. for (device, use_foreach), group in groups.items():
  604. operation_list: list[str] = []
  605. for (
  606. output_ind,
  607. args,
  608. ) in group:
  609. if allow_alpha:
  610. output = pw_fn(*args, alpha=alpha)
  611. else:
  612. output = pw_fn(*args)
  613. outputs[output_ind] = output
  614. if (
  615. V.graph.has_feature(device, BackendFeature.FOREACH)
  616. and use_foreach
  617. and realize_outputs
  618. ):
  619. output.realize()
  620. operation_list.append(output.get_operation_name())
  621. if operation_list:
  622. V.graph.register_operation_list(operation_list)
  623. assert all(x is not None for x in outputs)
  624. return outputs
  625. return inner
  626. def to_dtype(
  627. x: Union[TensorBox, ShapeAsConstantBuffer], dtype: torch.dtype, copy: bool = False
  628. ):
  629. src_dtype = x.get_dtype()
  630. if src_dtype == dtype:
  631. return clone(x) if copy else x
  632. def _to_dtype(x):
  633. return ops.to_dtype(x, dtype, src_dtype=src_dtype)
  634. return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)
  635. @register_lowering(torch._higher_order_ops._foreach_map, type_promotion_kind=None)
  636. def _foreach_map(subgraph, *args, **kwargs):
  637. """
  638. This lowers an invocation of foreach_map
  639. The way this works is that an arbitrary N-arg func is provided by the user, looped over by the
  640. polyfill with the same semantics as a foreach op (a loop applying an n-ary function to n args)
  641. and then traced into a subgraph by dynamo.
  642. This code allows us to inline the subgraph into the main graph lowering using the PontwiseSubgraphLowering.
  643. The graph outputs represent the vertically fused sequence of ops, and then register_operation_list
  644. below registers the buffers as horizontally fuseable in the scheduler.
  645. """
  646. from .subgraph_lowering import PointwiseSubgraphLowering
  647. inputs = args
  648. gm = subgraph.graph_module
  649. pw_subgraph = PointwiseSubgraphLowering(gm, root_graph_lowering=V.graph)
  650. with V.set_graph_handler(pw_subgraph): # type: ignore[arg-type]
  651. pw_subgraph.run(*inputs)
  652. sub_outputs = pw_subgraph.graph_outputs
  653. # group outputs by device and register as foreach
  654. assert sub_outputs # mypy lol
  655. groups = group_foreach_args(sub_outputs)
  656. outputs = [None] * len(sub_outputs)
  657. for (device, use_foreach), group in groups.items():
  658. operation_list: list[str] = []
  659. for (
  660. output_ind,
  661. output,
  662. ) in group:
  663. outputs[output_ind] = output
  664. if V.graph.has_feature(device, BackendFeature.FOREACH) and use_foreach:
  665. output.realize()
  666. operation_list.append(output.get_operation_name())
  667. if operation_list:
  668. V.graph.register_operation_list(operation_list)
  669. assert all(x is not None for x in outputs)
  670. return outputs
  671. @register_lowering(prims.convert_element_type, type_promotion_kind=None)
  672. def _convert_element_type(x: TensorBox, dtype: torch.dtype):
  673. if dtype.is_complex or x.get_dtype().is_complex:
  674. if x.get_size():
  675. # Decompose since aa aten fallback is more friendly for c++ codegen.
  676. # This decomposition doesn't work for empty tensor, which needs more investigation.
  677. dst = empty_like(x, dtype=dtype)
  678. ir.InplaceCopyFallback.create(dst, x)
  679. return dst
  680. else:
  681. return fallback_handler(
  682. prims.convert_element_type.default, add_to_fallback_set=False
  683. )(x, dtype)
  684. return to_dtype(x, dtype, copy=True)
  685. def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False):
  686. x_dtype = x.get_dtype()
  687. if x_dtype == dtype:
  688. return clone(x) if copy else x
  689. def _get_primitive_bitwidth(dtype):
  690. if dtype.is_floating_point:
  691. return torch.finfo(dtype).bits
  692. else:
  693. return torch.iinfo(dtype).bits
  694. src_bits = _get_primitive_bitwidth(x_dtype)
  695. dst_bits = _get_primitive_bitwidth(dtype)
  696. if src_bits != dst_bits:
  697. # fallback to aten eager implementation for differing bitwidths
  698. return fallback_handler(aten.view.dtype)(x, dtype)
  699. else:
  700. return TensorBox(DtypeView.create(x, dtype))
  701. @register_lowering(aten.view.dtype, type_promotion_kind=None)
  702. def _view_dtype(x: TensorBox, dtype: torch.dtype):
  703. if dtype.is_complex or x.get_dtype().is_complex:
  704. return TensorBox.create(
  705. ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype)
  706. )
  707. return to_dtype_bitcast(x, dtype)
  708. def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False):
  709. device = decode_device(device)
  710. if x.get_device() == device:
  711. return clone(x) if copy else x
  712. return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking))
  713. @register_lowering(prims.device_put, type_promotion_kind=None)
  714. def _device_put(x: TensorBox, device: torch.device, non_blocking=False):
  715. return to_device(x, device, copy=True, non_blocking=non_blocking)
  716. def register_pointwise(
  717. aten_fn,
  718. name=None,
  719. broadcast=True,
  720. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  721. convert_input_to_bool=False,
  722. override_return_dtype=None,
  723. override_fn_when_input_bool=None,
  724. allow_alpha=False,
  725. triton_fallback=None,
  726. ):
  727. """A pointwise function that maps ops.{name} to inputs"""
  728. name = name or aten_fn.__name__
  729. fn = ops_wrapper(name)
  730. register_op_dtype_propagation_rules(
  731. name, type_promotion_kind, override_return_dtype
  732. )
  733. if override_fn_when_input_bool is not None:
  734. override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool)
  735. fn = make_pointwise(
  736. fn,
  737. override_return_dtype=override_return_dtype,
  738. override_fn_when_input_bool=override_fn_when_input_bool,
  739. allow_alpha=allow_alpha,
  740. triton_fallback=triton_fallback,
  741. )
  742. fn = register_lowering(
  743. aten_fn,
  744. broadcast=broadcast,
  745. type_promotion_kind=type_promotion_kind,
  746. convert_input_to_bool=convert_input_to_bool,
  747. )(fn)
  748. if hasattr(prims, name):
  749. register_lowering(
  750. getattr(prims, name),
  751. type_promotion_kind=None,
  752. convert_input_to_bool=convert_input_to_bool,
  753. )(fn)
  754. return fn
  755. def register_frexp():
  756. """A pointwise function that maps ops.frexp to inputs"""
  757. name = "frexp"
  758. frexp = ops_wrapper("frexp")
  759. def frexp0(*args, **kwargs):
  760. return frexp(*args, **kwargs)[0] # type: ignore[index]
  761. def frexp1(*args, **kwargs):
  762. return frexp(*args, **kwargs)[1] # type: ignore[index]
  763. pw_fns = [
  764. make_pointwise(frexp0),
  765. make_pointwise(frexp1, override_return_dtype=torch.int32),
  766. ]
  767. def fn(*args, **kwargs):
  768. return pw_fns[0](*args, **kwargs), pw_fns[1](*args, **kwargs)
  769. fn = register_lowering(
  770. aten.frexp,
  771. )(fn)
  772. if hasattr(prims, name):
  773. register_lowering(
  774. getattr(prims, name),
  775. type_promotion_kind=None,
  776. )(fn)
  777. return fn
  778. register_frexp()
  779. def register_foreach_pointwise(
  780. aten_fn,
  781. pointwise_lowering_fn,
  782. allow_alpha=False,
  783. ):
  784. fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha)
  785. fn = _register_foreach_lowering(aten_fn, fn)
  786. return fn
  787. @register_lowering(aten.where, broadcast=False, type_promotion_kind=None)
  788. def where(cond, a, b):
  789. def fn(*args):
  790. return ops.where(*args)
  791. if isinstance(a, (float, int)):
  792. a = constant_like(a)(b)
  793. if isinstance(b, (float, int)):
  794. b = constant_like(b)(a)
  795. args = [cond, a, b]
  796. dtype = get_promoted_dtype(
  797. args[1], args[2], type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  798. )
  799. indices = [i for i, x in enumerate(args) if isinstance(x, TensorBox)]
  800. for i, x in zip(indices, broadcast_tensors(*[args[i] for i in indices])):
  801. args[i] = x
  802. for i in range(len(args)):
  803. if isinstance(args[i], ir.Constant):
  804. args[i] = ExpandView.create(args[i], list(args[indices[0]].get_size()))
  805. return make_pointwise(fn, override_return_dtype=dtype)(
  806. args[0], to_dtype(args[1], dtype), to_dtype(args[2], dtype)
  807. )
  808. @register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None)
  809. def broadcast_tensors(*inputs):
  810. if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
  811. return broadcast_tensors(*inputs[0])
  812. target: list[sympy.Expr] = functools.reduce(
  813. broadcast_symbolic_shapes, [x.get_size() for x in inputs], []
  814. )
  815. outputs = []
  816. for x in inputs:
  817. sizes = x.get_size()
  818. if len(sizes) != len(target) or any(
  819. V.graph.sizevars.is_size_one_or_false(a)
  820. != V.graph.sizevars.is_size_one_or_false(b)
  821. for a, b in zip(sizes, target)
  822. ):
  823. x = expand(x, target)
  824. outputs.append(x)
  825. return outputs
  826. @register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
  827. def nop(x):
  828. return x # AOT autograd handles this for us
  829. if hasattr(aten, "lift_fresh"):
  830. register_lowering(aten.lift_fresh)(nop)
  831. @register_lowering(aten.squeeze, type_promotion_kind=None)
  832. def squeeze(x, dim=None):
  833. assert isinstance(x, TensorBox)
  834. if dim is None:
  835. return TensorBox(SqueezeView.create(x.data))
  836. dim = (
  837. V.graph.sizevars.guard_int(dim)
  838. if isinstance(dim, (int, sympy.Expr))
  839. else tuple(V.graph.sizevars.guard_int(d) for d in dim)
  840. )
  841. dim = canonicalize_dims(len(x.get_size()), dim) # type: ignore[call-overload]
  842. dims = OrderedSet((dim,) if not isinstance(dim, tuple) else dim)
  843. new_shape = []
  844. for d, s in enumerate(x.get_size()):
  845. if not (d in dims and V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))):
  846. new_shape.append(s)
  847. # squeeze does nothing if the size isn't 1
  848. return view(x, new_shape) if new_shape != x.get_size() else x
  849. @register_lowering(aten.squeeze_copy, type_promotion_kind=None)
  850. def squeeze_copy(x, dim=None):
  851. return clone(squeeze(x, dim))
  852. @register_lowering([aten.squeeze_])
  853. def squeeze_(x, dim=None):
  854. val = squeeze(x, dim)
  855. assert isinstance(x, TensorBox)
  856. assert isinstance(val, TensorBox)
  857. x.data = val.data
  858. return x
  859. @register_lowering(aten.isinf)
  860. def isinf(x):
  861. if is_integer_type(x):
  862. return full_like(x, False, dtype=torch.bool)
  863. fn = ops_wrapper("isinf")
  864. return make_pointwise(fn, override_return_dtype=torch.bool)(x)
  865. @register_lowering(aten.isnan)
  866. def isnan(x):
  867. if is_integer_type(x):
  868. return full_like(x, False, dtype=torch.bool)
  869. fn = ops_wrapper("isnan")
  870. return make_pointwise(fn, override_return_dtype=torch.bool)(x)
  871. @register_lowering(aten.ceil)
  872. def ceil(x):
  873. if is_integer_type(x):
  874. return clone(x)
  875. fn = ops_wrapper("ceil")
  876. return make_pointwise(fn)(x)
  877. @register_lowering(aten.floor)
  878. def floor(x):
  879. if is_integer_type(x):
  880. return clone(x)
  881. fn = ops_wrapper("floor")
  882. return make_pointwise(fn)(x)
  883. @register_lowering(aten.round.default)
  884. def round(x):
  885. if is_integer_type(x):
  886. return clone(x)
  887. else:
  888. fn = ops_wrapper("round")
  889. return make_pointwise(fn)(x)
  890. @register_lowering(aten.trunc)
  891. def trunc(x):
  892. if is_integer_type(x):
  893. return clone(x)
  894. fn = ops_wrapper("trunc")
  895. return make_pointwise(fn)(x)
  896. @register_lowering(aten.expand, type_promotion_kind=None)
  897. def expand(x, sizes):
  898. (x,) = promote_constants([x])
  899. if isinstance(x, ir.BaseConstant):
  900. return ExpandView.create(x, tuple(sizes))
  901. assert isinstance(x, TensorBox)
  902. assert isinstance(sizes, (list, tuple))
  903. if tuple(x.get_size()) == tuple(sizes):
  904. return x
  905. if not free_unbacked_symbols(x.get_size()):
  906. x_size_product = V.graph.sizevars.size_hint_or_throw(
  907. sympy_product(x.get_size())
  908. )
  909. # TODO: It would be better to realize the input if any of its sizes
  910. # are unbacked, because typically the size will be non-zero. However,
  911. # this cannot be done directly as below as we'll choke on the size_hint
  912. # here
  913. if x_size_product > 0 and not free_unbacked_symbols(sizes):
  914. # maybe realize input before broadcasting it
  915. x.mark_reuse(
  916. V.graph.sizevars.size_hint_or_throw(sympy_product(sizes))
  917. // x_size_product
  918. )
  919. return TensorBox(ExpandView.create(x.data, tuple(sizes)))
  920. @register_lowering(prims.broadcast_in_dim, type_promotion_kind=None)
  921. def broadcast_in_dim(a, shape, broadcast_dimensions):
  922. s = list(shape)
  923. for broadcast_dimension in broadcast_dimensions:
  924. s[broadcast_dimension] = -1
  925. v = a
  926. for idx, x in enumerate(s):
  927. if x != -1:
  928. v = unsqueeze(v, idx)
  929. return expand(v, shape)
  930. @register_lowering(aten.expand_as, type_promotion_kind=None)
  931. def expand_as(x, y):
  932. return expand(x, y.get_size())
@register_lowering(aten.repeat)
def repeat(x, repeats):
    """Lower ``aten.repeat``: tile ``x`` ``repeats[i]`` times along dim ``i``.

    If ``repeats`` is longer than ``x``'s rank, ``x`` is first viewed with
    leading size-1 dims. Then one of three strategies is used:

    * any repeat of 0 -> ``empty`` of the output shape (which has 0 elements);
    * every dim has repeat 1 or input size 1 -> ``clone(expand(x, new_size))``;
    * otherwise a ``Pointwise`` whose loader maps each repeated output index
      back into the input with ``ModularIndexing`` (or pins it to 0 when the
      input dim has size 1).
    """
    old_size = list(x.get_size())
    if len(repeats) > len(old_size):
        # Left-pad the shape with 1s so that rank matches len(repeats).
        old_size = [sympy.S.One] * (len(repeats) - len(old_size)) + old_size
        x = view(x, list(old_size))
    assert len(repeats) == len(x.get_size())

    new_size = list(x.get_size())

    zero_tensor = False
    for i in range(len(repeats)):
        if repeats[i] == 0:
            zero_tensor = True
        new_size[i] = new_size[i] * repeats[i]

    if zero_tensor:
        return empty(new_size, dtype=x.get_dtype(), device=x.get_device())
    if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)):
        # Pure broadcast: no dim both repeats and has >1 element.
        return clone(expand(x, new_size))

    # Declared here but bound only after the mark_reuse call below, so the
    # loader is created from x *after* mark_reuse has run.
    x_loader: Callable[[Any], Any]

    def inner_fn(index):
        assert len(index) == len(repeats)
        index = list(index)
        for i in range(len(repeats)):
            if repeats[i] != 1:
                if old_size[i] == 1:
                    index[i] = sympy.S.Zero
                else:
                    # Wrap the output index back into [0, old_size[i]).
                    index[i] = ModularIndexing(index[i], 1, old_size[i])
        return x_loader(index)

    if not free_unbacked_symbols(old_size) and not free_unbacked_symbols(new_size):
        old_size_product = V.graph.sizevars.size_hint_or_throw(sympy_product(old_size))
        if old_size_product > 0:
            # maybe realize the input but skip for unbacked symints since it'll
            # choke on the size hint.
            x.mark_reuse(
                V.graph.sizevars.size_hint_or_throw(sympy_product(new_size))
                // old_size_product
            )

    x_loader = x.make_loader()
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(new_size),
    )
  977. @register_lowering(aten._unsafe_view, type_promotion_kind=None)
  978. @register_lowering(aten.view, type_promotion_kind=None)
  979. @register_lowering(aten.reshape, type_promotion_kind=None)
  980. def view(x: TensorBox, sizes: Sequence[sympy.Expr]) -> TensorBox:
  981. return TensorBox(View.create(x.data, sizes))
  982. @register_lowering(aten.permute, type_promotion_kind=None)
  983. def permute(x, dims):
  984. assert isinstance(x, TensorBox)
  985. assert isinstance(dims, (list, tuple))
  986. return TensorBox(PermuteView.create(x.data, tuple(dims)))
  987. @register_lowering(aten.slice, type_promotion_kind=None)
  988. def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
  989. assert isinstance(x, TensorBox)
  990. dim = _validate_dim(x, dim, 0)
  991. return TensorBox(ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp))
  992. @register_lowering(aten.as_strided, type_promotion_kind=None)
  993. def as_strided(x, size, stride, storage_offset=None):
  994. if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
  995. # as_strided ignores views
  996. x = x.data.unwrap_view()
  997. x.realize()
  998. if not ir.is_storage_and_layout(x):
  999. raise NotImplementedError(f"unrealized as_strided({x}, ...)")
  1000. storage, old_layout = ir.as_storage_and_layout(x)
  1001. new_layout = ir.FixedLayout(
  1002. old_layout.device,
  1003. old_layout.dtype,
  1004. [sympy.expand(s) for s in size],
  1005. [sympy.expand(s) for s in stride],
  1006. sympy.expand(storage_offset or 0),
  1007. )
  1008. return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout))
  1009. @register_lowering(aten.as_strided_, type_promotion_kind=None)
  1010. def as_strided_(x, size, stride, storage_offset=None):
  1011. assert isinstance(x, TensorBox)
  1012. x.data = as_strided(x, size, stride, storage_offset).data
  1013. return x
  1014. @register_lowering(aten.as_strided_copy, type_promotion_kind=None)
  1015. def as_strided_copy(x, size, stride, storage_offset=None):
  1016. result = as_strided(x, size, stride, storage_offset)
  1017. return clone(result)
def pointwise_cat(inputs, dim=0):
    """Concatenate ``inputs`` along ``dim`` as a single ``Pointwise`` node.

    Each input owns a half-open range ``[start, end)`` of the output's ``dim``
    axis. For every output index, each input produces a masked load, with its
    index shifted by that input's ``start``. The loads are then combined with
    a chain of ``ops.where`` selecting the first input whose mask is true.
    Size and device come from ``inputs[0]``, except that ``dim`` is the
    summed length.
    """
    # (inclusive, exclusive)
    inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = []
    prev_end = 0
    for inp in inputs:
        inputs_ranges.append((prev_end, prev_end + inp.get_size()[dim]))  # type: ignore[arg-type]
        prev_end = inputs_ranges[-1][-1]  # type: ignore[assignment]

    inputs_loaders = [inp.make_loader() for inp in inputs]

    def inner_fn(idx):
        idx_dim = ops.index_expr(idx[dim], torch.int64)

        masks = []
        masked_loads = []
        for i in range(len(inputs)):
            start = (
                ops.constant(0, torch.int64)
                if i == 0
                else ops.index_expr(inputs_ranges[i][0], torch.int64)
            )
            end = ops.index_expr(inputs_ranges[i][1], torch.int64)

            start_cond = ops.ge(idx_dim, start)
            end_cond = ops.lt(idx_dim, end)
            # The first input needs only the upper bound and the last only the
            # lower bound; the where-chain below handles the rest.
            if i == 0:
                mask = end_cond
            elif i == len(inputs) - 1:
                mask = start_cond
            else:
                mask = ops.and_(start_cond, end_cond)

            masks.append(mask)
            idx_load = list(idx)

            # if we're concatting [4], [2]
            # when we index the second tensor for 5 we want to index 5 - 4
            # Use Identity to prevent expansion of index * stride to keep expression
            # in same int bitwidth as shape
            idx_load[dim] = Identity(idx_load[dim] - inputs_ranges[i][0])

            # NOTE(review): this lambda captures the loop variables `i` and
            # `idx_load` by reference. That is only correct if ops.masked
            # invokes it before the next loop iteration. Confirm, or bind
            # them as lambda defaults.
            masked_loads.append(
                ops.masked(
                    mask,
                    lambda: inputs_loaders[i](idx_load),
                    0.0,  # this value should be unused
                ),
            )

        # Fold from the last input backwards: where(mask_i, load_i, rest).
        next_val = masked_loads[-1]
        for i in range((len(inputs)) - 2, -1, -1):
            next_val = ops.where(
                masks[i],
                masked_loads[i],
                next_val,
            )
        return next_val

    new_size = list(inputs[0].get_size())
    new_size[dim] = inputs_ranges[-1][-1]

    return Pointwise.create(
        device=inputs[0].get_device(),
        dtype=inputs[0].get_dtype(),
        inner_fn=inner_fn,
        ranges=new_size,
    )
@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None)
def quantized_decomposed_quantize_per_channel(
    input: TensorBox,
    scales: TensorBox,
    zero_points: TensorBox,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Per-channel quantization along ``axis``.

    For each element, with ``c = idx[axis]``, computes
    ``clamp(round(input * (1 / scales[c])) + zero_points[c], quant_min, quant_max)``
    and casts the result to ``dtype``. bfloat16 inputs are upcast to float32
    first; any other non-float32 input fails the assertion.
    """
    assert len(scales.get_size()) == 1, "expect scales 1 dim"
    assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"

    if input.get_dtype() == torch.bfloat16:
        input = to_dtype(input, torch.float32)
    assert input.get_dtype() == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
    )
    assert axis < len(input.get_size()), (
        f"Expecting axis to be < {len(input.get_size())}"
    )

    input_loader = input.make_loader()
    scales_loader = scales.make_loader()
    zero_points_loader = zero_points.make_loader()

    def inner_fn(idx):
        # scales / zero_points are 1-D and indexed by the channel coordinate.
        channel_idx = (idx[axis],)

        input = input_loader(idx)
        scale = scales_loader(channel_idx)
        zero_point = zero_points_loader(channel_idx)
        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)

        if scales.dtype != torch.float32:
            scale = ops.to_dtype(scale, torch.float32)
        if zero_points.dtype != torch.int32:
            zero_point = ops.to_dtype(zero_point, torch.int32)
        inv_scale = ops.reciprocal(scale)
        # NOTE(review): an int32 zero_point is added to a float32 value here;
        # presumably promoted to float by the ops handler -- confirm.
        val = ops.round(input * inv_scale) + zero_point
        clamped = ops.maximum(qmin, ops.minimum(qmax, val))
        return ops.to_dtype(clamped, dtype)

    return Pointwise.create(
        device=input.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
def _assert_async(cond, msg):
    """Emit a device-side async assertion that every element of ``cond`` holds.

    ``cond`` is realized and cast to bool. A bool ``Pointwise`` over
    ``cond``'s shape then calls ``ops.device_assert_async`` on each element
    with ``msg``, and is realized immediately so that it is not dropped.
    Returns that realized node.
    """
    cond.realize()
    cond = to_dtype(cond, torch.bool)

    def inner_fn(index):
        with ir.ComputedBuffer.force_realize():
            return ops.device_assert_async(cond.make_loader()(index), msg)

    assertion_op = Pointwise.create(
        device=cond.get_device(),
        dtype=cond.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(cond.get_size()),
    )
    assertion_op.realize()
    return assertion_op
  1132. @register_lowering(aten._assert_async.msg)
  1133. def lower_assert_async(cond, msg):
  1134. return _assert_async(cond, msg)
  1135. @register_lowering(aten._functional_assert_async.msg)
  1136. def lower_assert_functional_async(cond, msg):
  1137. return _assert_async(cond, msg)
@register_lowering(
    quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_channel(
    input: TensorBox,
    scales: TensorBox,
    zero_points: TensorBox,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Per-channel dequantization along ``axis``.

    For each element, with ``c = idx[axis]``, computes
    ``(float32(input) - zero_points[c]) * scales[c]`` and casts the result to
    ``out_dtype`` (float32 when not given). ``quant_min``/``quant_max`` are
    accepted but not used by this lowering.
    """
    assert len(scales.get_size()) == 1, "expect scales 1 dim"
    assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
    assert input.get_dtype() == dtype, (
        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
    )
    assert axis < len(input.get_size()), (
        f"Expecting axis to be < {len(input.get_size())}"
    )

    if out_dtype is None:
        out_dtype = torch.float32

    input_loader = input.make_loader()
    scales_loader = scales.make_loader()
    zero_points_loader = zero_points.make_loader()

    def inner_fn(idx):
        channel_idx = (idx[axis],)

        input = input_loader(idx)
        scale = scales_loader(channel_idx)
        zero_point = zero_points_loader(channel_idx)

        # All arithmetic is done in float32, then cast to out_dtype.
        if scales.dtype != torch.float32:
            scale = ops.to_dtype(scale, torch.float32)
        if zero_points.dtype != torch.float32:
            zero_point = ops.to_dtype(zero_point, torch.float32)
        val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale
        val = ops.to_dtype(val, out_dtype)
        return val

    return Pointwise.create(
        device=input.get_device(),
        dtype=out_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_default(
    input: TensorBox,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Per-tensor quantization with Python-scalar ``scale``/``zero_point``.

    Computes ``clamp(round(input * (1.0 / scale)) + zero_point, quant_min,
    quant_max)`` in float32 and casts the result to ``dtype``. bfloat16
    inputs are upcast to float32 first.
    """
    if input.get_dtype() == torch.bfloat16:
        input = to_dtype(input, torch.float32)
    assert input.get_dtype() == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
    )

    input_loader = input.make_loader()

    def inner_fn(idx, scale, zero_point):
        input = input_loader(idx)
        # 1.0 / scale is computed in Python at trace time and baked in as a constant.
        inv_scale, zero_point = _create_constants(
            1.0 / scale, zero_point, dtype=torch.float32
        )
        val = ops.round(input * inv_scale) + zero_point
        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
        clamped = ops.minimum(ops.maximum(val, qmin), qmax)
        return ops.to_dtype(clamped, dtype)

    return Pointwise.create(
        device=input.get_device(),
        dtype=dtype,
        # scale / zero_point are normalized to float / int before being bound.
        inner_fn=functools.partial(
            inner_fn, scale=float(scale), zero_point=int(zero_point)
        ),
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_default(
    input: TensorBox,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Per-tensor dequantization with Python-scalar ``scale``/``zero_point``.

    Computes ``(float32(input) - zero_point) * scale`` and casts the result to
    ``out_dtype`` (float32 when not given). ``quant_min``/``quant_max`` are
    accepted but not used by this lowering.
    """
    assert input.get_dtype() == dtype, (
        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
    )

    if out_dtype is None:
        out_dtype = torch.float32

    input_loader = input.make_loader()

    def inner_fn(idx, scale, zero_point):
        input = input_loader(idx)
        scale, zero_point = _create_constants(scale, zero_point, dtype=torch.float32)
        val = ops.sub(ops.to_dtype(input, torch.float32), zero_point) * scale
        val = ops.to_dtype(val, out_dtype)
        return val

    return Pointwise.create(
        device=input.get_device(),
        dtype=out_dtype,
        # scale / zero_point are normalized to float / int before being bound.
        inner_fn=functools.partial(
            inner_fn, scale=float(scale), zero_point=int(zero_point)
        ),
        ranges=input.get_size(),
    )
  1250. @register_lowering(
  1251. quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None
  1252. )
  1253. def quantized_decomposed_quantize_per_tensor_tensor(
  1254. input: TensorBox,
  1255. scale: TensorBox,
  1256. zero_point: TensorBox,
  1257. quant_min: int,
  1258. quant_max: int,
  1259. dtype: torch.dtype,
  1260. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1261. if input.get_dtype() == torch.bfloat16:
  1262. input = to_dtype(input, torch.float32)
  1263. assert input.get_dtype() == torch.float32, (
  1264. f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
  1265. )
  1266. assert len(scale.get_size()) == 0 or (
  1267. len(scale.get_size()) == 1 and scale.get_size()[0] == 1
  1268. ), "expect scale as scalar tensor"
  1269. assert len(zero_point.get_size()) == 0 or (
  1270. len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
  1271. ), "expect zero_point as scalar tensor"
  1272. input_loader = input.make_loader()
  1273. scale_loader = scale.make_loader()
  1274. zero_point_loader = zero_point.make_loader()
  1275. def inner_fn(idx):
  1276. input = input_loader(idx)
  1277. _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
  1278. _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
  1279. if scale.dtype != torch.float32:
  1280. _scale = ops.to_dtype(_scale, torch.float32)
  1281. if zero_point.dtype != torch.float32:
  1282. _zero_point = ops.to_dtype(_zero_point, torch.float32)
  1283. val = ops.round(input * ops.reciprocal(_scale)) + _zero_point
  1284. qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
  1285. clamped = ops.minimum(ops.maximum(val, qmin), qmax)
  1286. return ops.to_dtype(clamped, dtype)
  1287. return Pointwise.create(
  1288. device=input.get_device(),
  1289. dtype=dtype,
  1290. inner_fn=inner_fn,
  1291. ranges=input.get_size(),
  1292. )
  1293. @register_lowering(
  1294. quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None
  1295. )
  1296. def quantized_decomposed_dequantize_per_tensor_tensor(
  1297. input: TensorBox,
  1298. scale: TensorBox,
  1299. zero_point: TensorBox,
  1300. quant_min: int,
  1301. quant_max: int,
  1302. dtype: torch.dtype,
  1303. *,
  1304. out_dtype: Optional[torch.dtype] = None,
  1305. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1306. assert len(scale.get_size()) == 0 or (
  1307. len(scale.get_size()) == 1 and scale.get_size()[0] == 1
  1308. ), "expect scale as scalar tensor"
  1309. assert len(zero_point.get_size()) == 0 or (
  1310. len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
  1311. ), "expect zero_point as scalar tensor"
  1312. assert input.get_dtype() == dtype, (
  1313. f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
  1314. )
  1315. if out_dtype is None:
  1316. out_dtype = torch.float32
  1317. input_loader = input.make_loader()
  1318. scale_loader = scale.make_loader()
  1319. zero_point_loader = zero_point.make_loader()
  1320. def inner_fn(idx):
  1321. input = input_loader(idx)
  1322. _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
  1323. _zero_point = zero_point_loader((0,) if len(scale.get_size()) == 1 else ())
  1324. if scale.dtype != torch.float32:
  1325. _scale = ops.to_dtype(_scale, torch.float32)
  1326. if zero_point.dtype != torch.float32:
  1327. _zero_point = ops.to_dtype(_zero_point, torch.float32)
  1328. val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale
  1329. val = ops.to_dtype(val, out_dtype)
  1330. return val
  1331. return Pointwise.create(
  1332. device=input.get_device(),
  1333. dtype=out_dtype,
  1334. inner_fn=inner_fn,
  1335. ranges=input.get_size(),
  1336. )
@register_lowering(aten.cat)
def cat(inputs, dim=0):
    """Lower ``aten.cat``, choosing between a fallback, a ConcatKernel or pointwise_cat.

    Decision order:

    1. CPU with all int8/uint8 inputs -> realize the inputs (and, for 4-D
       inputs, convert them with ``require_channels_last``), then fall back
       to the eager ``aten.cat.default``.
    2. A single input -> ``clone``.
    3. Otherwise, after dtype promotion: ``config.force_pointwise_cat`` ->
       ``pointwise_cat``; CPU -> ``ConcatKernel``.
    4. With few inputs (or few *simple* inputs), use ``pointwise_cat`` when
       the node is only used pointwise and some input would otherwise need a
       copy, or when every input needs a copy and none can fuse a reduction.
    5. Anything else -> ``ConcatKernel``.
    """
    cpu_device = inputs[0].get_device().type == "cpu"
    if cpu_device and all(
        input.get_dtype() in [torch.int8, torch.uint8] for input in inputs
    ):
        # TODO <leslie> Remove this fallback when we support vectorization
        # code gen with uint8 data type directly.
        for input in inputs:
            input.realize()
        if all(len(input.get_size()) == 4 for input in inputs):
            inputs, _ = require_channels_last(aten.cat, *inputs)
        return fallback_handler(aten.cat.default)(inputs, dim)

    if len(inputs) == 1:
        return clone(inputs[0])

    dim = _validate_dim(inputs[0], dim, 0)
    dtype = get_promoted_dtype(
        *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    inputs = [to_dtype(inp, dtype) for inp in inputs]

    def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode:
        # Strip TensorBox / StorageBox (and a view, if any) to reach the node
        # that actually computes the data.
        if isinstance(x, TensorBox):
            if isinstance(x.data, ir.BaseView):
                return x.data.unwrap_view()
            else:
                return x.data

        if isinstance(x, ir.StorageBox):
            return x.data

        return x

    def is_reduction(t):
        return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction)

    def can_fuse_reduction(t):
        # True if t is a reduction buffer, or a Pointwise that (transitively)
        # reads one. Precedence: `a or (b and c)`.
        if isinstance(t, (TensorBox, ir.StorageBox)):
            return can_fuse_reduction(unwrap_tensor(t))
        return (
            is_reduction(t)
            or isinstance(t, ir.Pointwise)
            and any(
                can_fuse_reduction(V.graph.get_buffer(read))
                for read in t.get_read_names()
            )
        )

    # fusing reductions into computed concat buffer can cause regressions.
    fusable_reduction = any(can_fuse_reduction(t) for t in inputs)

    def should_lower_cat_input(x) -> bool:
        # Unrealized inputs will not be storage and layouts, and we don't want to realize
        # them in case we want to fuse
        if ir.is_storage_and_layout(x):
            storage, _ = ir.as_storage_and_layout(x, freeze=False)
            return not ir.ConcatKernel.can_realize_into_without_copy(storage)

        if isinstance(x, (TensorBox, ir.StorageBox)):
            return should_lower_cat_input(unwrap_tensor(x))

        if isinstance(x, ir.Pointwise):
            return True

        return False

    if config.force_pointwise_cat:
        return pointwise_cat(inputs, dim)

    # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it.
    #             We will revisit this later after enabling vectorization on index_expr.
    if cpu_device:
        return TensorBox(ir.ConcatKernel.create(inputs, dim))

    def op_count(x):
        # Number of ops in x's inner_fn plus, recursively, those of any
        # Pointwise it reads from.
        if isinstance(x, (TensorBox, ir.StorageBox)):
            return op_count(unwrap_tensor(x))

        # this will correspond to a direct memory read
        if not isinstance(x, ir.Pointwise):
            return 0

        count = x.inner_fn_opcount().num_ops
        for read in x.get_read_names():
            count += op_count(V.graph.get_buffer(read))

        return count

    # as the number of inputs increases, the possibility of register spilling also increases;
    # past a certain threshold of inputs we only fuse if the input kernels
    # are simple
    # not sure if we want to expose to users via config since logic may change in future
    MAX_COMPLEX_POINTWISE_CAT = 8
    MAX_SIMPLE_OP_COUNT = 2

    def additional_pointwise_ops(op: torch._ops.OpOverload):
        return op in (aten.cat.default, aten.constant_pad_nd.default)

    if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or (
        (len(inputs) <= config.max_pointwise_cat_inputs)
        and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
    ):
        pointwise_uses = all(
            is_pointwise_use(use, additional_pointwise_ops)
            for use in V.current_node.users
        )
        # fuse in case we will be used in a pointwise node, and there are any inputs we
        # can prevent materialization of.
        fuse_pointwise_use = (
            any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses
        )

        # horizontal fuse in case all inputs will require a copy kernel anyway.
        # only horizontally fuse pointwise kernels
        # NOTE(review): `not any(can_fuse_reduction(...))` repeats the
        # `fusable_reduction` check used in the condition below.
        horizontal_fuse_cat = all(
            should_lower_cat_input(inp) for inp in inputs
        ) and not any(can_fuse_reduction(t) for t in inputs)
        if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction):
            return pointwise_cat(inputs, dim)

    return TensorBox(ir.ConcatKernel.create(inputs, dim))
  1437. @register_lowering(aten.diagonal, type_promotion_kind=None)
  1438. def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
  1439. original_shape = input.get_size()
  1440. num_dims = len(original_shape)
  1441. dim1 = canonicalize_dim(idx=dim1, rank=num_dims)
  1442. dim2 = canonicalize_dim(idx=dim2, rank=num_dims)
  1443. check(
  1444. dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
  1445. )
  1446. offset_negative = V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0))
  1447. if offset_negative:
  1448. diag_size = V.graph.sizevars.evaluate_max(
  1449. V.graph.sizevars.evaluate_min(
  1450. original_shape[dim1] + offset, original_shape[dim2]
  1451. ),
  1452. 0, # type: ignore[arg-type]
  1453. )
  1454. else:
  1455. diag_size = V.graph.sizevars.evaluate_max(
  1456. V.graph.sizevars.evaluate_min(
  1457. original_shape[dim1], original_shape[dim2] - offset
  1458. ),
  1459. 0, # type: ignore[arg-type]
  1460. )
  1461. base_idx = (0, 0)
  1462. if offset_negative:
  1463. base_idx = (-offset, 0)
  1464. else:
  1465. base_idx = (0, offset)
  1466. sizes = [s for i, s in enumerate(original_shape) if i not in (dim1, dim2)]
  1467. sizes.append(diag_size)
  1468. def reindexer(idx):
  1469. diag_idx = idx[-1]
  1470. original_idx = [0] * len(original_shape)
  1471. cur_dim = 0
  1472. for d in range(num_dims):
  1473. if d == dim1:
  1474. original_idx[d] = diag_idx + base_idx[0]
  1475. elif d == dim2:
  1476. original_idx[d] = diag_idx + base_idx[1]
  1477. else:
  1478. original_idx[d] = idx[cur_dim]
  1479. cur_dim += 1
  1480. assert cur_dim == len(original_shape) - 2
  1481. return original_idx
  1482. return TensorBox(ir.GenericView.create(input, sizes, reindexer))
  1483. @register_lowering(aten.diagonal_copy, type_promotion_kind=None)
  1484. def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
  1485. return clone(diagonal(input, offset, dim1, dim2))
  1486. @register_lowering(aten.diagonal_scatter, type_promotion_kind=None)
  1487. def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
  1488. output = clone(input)
  1489. target = diagonal(output, offset, dim1, dim2)
  1490. mutate_to(target, src)
  1491. return output
  1492. @register_lowering(aten.select, type_promotion_kind=None)
  1493. def select(x, dim, idx):
  1494. idx = sympy.expand(idx)
  1495. size = sympy.expand(x.get_size()[dim])
  1496. actual_index = None
  1497. if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)):
  1498. actual_index = idx + size
  1499. elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)):
  1500. actual_index = idx
  1501. if actual_index is not None:
  1502. if has_free_unbacked_symbols(idx):
  1503. # Inductor could generate incorrect views for tensors with unbacked symbols here;
  1504. # Squeeze operations are translated to views, resulting in incorrect strides.
  1505. # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
  1506. # we use as_strided instead.
  1507. # Removing this branch will cause test_unbacked_select_index_with_check to fail.
  1508. new_size = x.get_size()
  1509. new_stride = x.get_stride()
  1510. new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index
  1511. del new_size[dim]
  1512. del new_stride[dim]
  1513. return as_strided(x, new_size, new_stride, new_storage_offset)
  1514. else:
  1515. slice_result = slice_(x, dim, actual_index, actual_index + 1)
  1516. return squeeze(slice_result, dim)
  1517. # Unbacked Semantics:
  1518. # When the index idx is unbacked (e.g., u0), we compute the index dynamically
  1519. # during the lowering of the select operation using DynamicSelectStorageOffset.
  1520. unbacked_bindings = resolve_unbacked_bindings(
  1521. V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
  1522. )
  1523. assert unbacked_bindings is not None
  1524. assert len(unbacked_bindings) == 1, unbacked_bindings
  1525. unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))
  1526. new_size = x.get_size()
  1527. new_stride = x.get_stride()
  1528. new_storage_offset = unbacked_offset_sym
  1529. buffer = ir.DynamicSelectStorageOffset(
  1530. unbacked_offset_sym,
  1531. idx,
  1532. x.get_layout().offset,
  1533. new_stride[dim],
  1534. x.get_size()[dim],
  1535. )
  1536. buffer.name = V.graph.register_buffer(buffer)
  1537. V.graph.register_operation(buffer)
  1538. del new_size[dim]
  1539. del new_stride[dim]
  1540. return as_strided(x, new_size, new_stride, new_storage_offset)
  1541. @register_lowering(aten.split, type_promotion_kind=None)
  1542. def split(x, sizes, dim=0):
  1543. dim = _validate_dim(x, dim, 0)
  1544. sizes_ = sizes
  1545. # If sizes is an integer (or a SymInt), we turn it into a list of sizes
  1546. # by computing what the actual size of each chunk should be.
  1547. if not isinstance(sizes, (list, tuple)):
  1548. x_size = x.get_size()[dim]
  1549. chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes))
  1550. sizes_ = [sizes] * chunks
  1551. # The last chunk might have a smaller size than the rest.
  1552. sizes_[-1] = x_size - (chunks - 1) * sizes
  1553. # From this point, we assume that the sum of the sizes of all chunks
  1554. # equals the size of the base tensor.
  1555. result = []
  1556. start = 0
  1557. for size in sizes_:
  1558. end = start + size
  1559. # No need for clamping here, since we compute the exact
  1560. # start and end values.
  1561. result.append(slice_(x, dim, start, end, clamp=False))
  1562. start = end
  1563. return result
  1564. @register_lowering(aten.split_with_sizes, type_promotion_kind=None)
  1565. def split_with_sizes(x, sizes, dim=0):
  1566. return split(x, sizes, dim)
  1567. @register_lowering(aten.unbind, type_promotion_kind=None)
  1568. def unbind(x, dim=0):
  1569. dim = _validate_dim(x, dim, 0)
  1570. x_size = V.graph.sizevars.guard_int(x.get_size()[dim])
  1571. result = [select(x, dim, i) for i in range(x_size)]
  1572. return result
@register_lowering(aten.unfold, type_promotion_kind=None)
def unfold(x, dimension, size, step):
    """Lower ``aten.unfold`` to a ``GenericView`` of sliding windows.

    Dim ``dim`` is replaced by the number of windows,
    ``(dim_size - size) // step + 1``, and a trailing dim of length ``size``
    is appended. Element ``j`` of window ``w`` reads input index
    ``w * step + j`` along ``dim``. A 0-d input is treated as
    ``unsqueeze(x, 0)[:size]``.
    """
    sizes = x.get_size()
    ndim = len(sizes)
    dim = canonicalize_dim(ndim, dimension)

    if ndim == 0:
        return slice_(unsqueeze(x, 0), end=size)

    dim_size = sizes[dim]
    sizevars = V.graph.sizevars
    sizevars.check_leq(size, dim_size)
    sizevars.check_lt(0, step)  # type: ignore[arg-type]

    new_dim_size = FloorDiv(dim_size - size, step) + 1
    if sizevars.size_hint_or_throw(dim_size) > 0:
        # Reuse hint: ceil(total elements read across all windows / dim_size),
        # i.e. roughly how many times each input element is read.
        x.mark_reuse(
            sizevars.size_hint_or_throw(CeilDiv(new_dim_size * size, dim_size))
        )

    out_size = [*sizes[:dim], new_dim_size, *sizes[dim + 1 :], size]

    def reindexer(idx):
        # idx[dim] is the window number, idx[-1] the offset within the window.
        dim_idx = idx[-1] + idx[dim] * step
        return (*idx[:dim], dim_idx, *idx[dim + 1 : -1])

    return TensorBox(ir.GenericView.create(x, out_size, reindexer))
  1594. @register_lowering(aten.unsqueeze, type_promotion_kind=None)
  1595. def unsqueeze(x, dim):
  1596. dim = _validate_dim(x, dim, 1)
  1597. new_shape = list(x.get_size())
  1598. new_shape.insert(dim, sympy.S.One)
  1599. return view(x, new_shape)
  1600. @register_lowering(aten.unsqueeze_, type_promotion_kind=None)
  1601. def unsqueeze_(x, dim):
  1602. val = unsqueeze(x, dim)
  1603. assert isinstance(x, TensorBox)
  1604. assert isinstance(val, TensorBox)
  1605. x.data = val.data
  1606. return x
  1607. def _validate_dim(x, dim, offset=0):
  1608. dim = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim))
  1609. ndim = len(x.get_size())
  1610. if dim < 0:
  1611. dim += ndim + offset
  1612. assert 0 <= dim < ndim + offset
  1613. return dim
  1614. @register_lowering(aten.glu)
  1615. def glu(x, dim=-1):
  1616. dim = _validate_dim(x, dim, 0)
  1617. # TODO: don't guard on static shape here
  1618. new_len = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2
  1619. a = slice_(x, dim, 0, new_len)
  1620. b = slice_(x, dim, new_len, new_len * 2)
  1621. return mul(a, sigmoid(b))
  1622. def fallback_handler(kernel, add_to_fallback_set=True):
  1623. if add_to_fallback_set:
  1624. fallbacks.add(kernel)
  1625. def handler(*args, **kwargs):
  1626. def wrap_tensors(x):
  1627. return TensorBox.create(x) if isinstance(x, ir.IRNode) else x
  1628. return pytree.tree_map(
  1629. wrap_tensors, ir.FallbackKernel.create(kernel, *args, **kwargs)
  1630. )
  1631. # This lets us detect that a lowering is a fallback handler.
  1632. handler._is_fallback_handler = True # type: ignore[attr-defined]
  1633. return handler
  1634. @functools.cache
  1635. def _warn_complex_not_supported():
  1636. warnings.warn(
  1637. "Torchinductor does not support code generation for complex operators. Performance may be worse than eager."
  1638. )
  1639. # There are some types (CPU) which we accept as input but not as
  1640. # output.
  1641. def unsupported_input_tensor(t: torch.Tensor, node=None):
  1642. "Do not support reading or writing to this tensor"
  1643. if t.is_complex():
  1644. # Complex views are supported with IR ComplexView
  1645. _warn_complex_not_supported()
  1646. return True
  1647. if t.is_meta:
  1648. return True
  1649. if t.dtype == torch.float8_e8m0fnu:
  1650. if not node:
  1651. return True
  1652. # allow bitcast, views, memory movement, but not arithmetic
  1653. # TODO: delete once triton adds native support
  1654. return not (
  1655. isinstance(node.target, torch._ops.OpOverload)
  1656. and node.target
  1657. in (
  1658. aten.view.dtype,
  1659. aten.cat.default,
  1660. aten.clone.default,
  1661. aten._scaled_mm.default,
  1662. )
  1663. or (isinstance(node.target, torch._ops.OpOverload) and is_view(node.target))
  1664. )
  1665. return False
  1666. def unsupported_output_tensor(t: torch.Tensor, node=None):
  1667. "Do not support writing tensor but can read from it"
  1668. supported_complex_views = (
  1669. aten.view.dtype,
  1670. torch.ops.prims.convert_element_type.default,
  1671. )
  1672. if node is not None and node.target in supported_complex_views and t.is_complex():
  1673. return False
  1674. if unsupported_input_tensor(t, node):
  1675. return True
  1676. return t.is_cpu and config.disable_cpp_codegen
  1677. def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True):
  1678. # Custom fallback lowering
  1679. if node.target is aten.view_as_complex.default:
  1680. return False
  1681. if node.op == "placeholder":
  1682. return False
  1683. # We should be able to remove this special case once `disable_cpp_codegen` is killed.
  1684. if node.target is aten.lift_fresh_copy.default:
  1685. return False
  1686. def check_skip_condition(inp_out_node, is_output):
  1687. if not isinstance(inp_out_node, torch.fx.Node):
  1688. return False
  1689. if "val" not in inp_out_node.meta:
  1690. return False
  1691. for meta in pytree.tree_leaves(inp_out_node.meta["val"]):
  1692. if not isinstance(meta, torch._subclasses.FakeTensor):
  1693. continue
  1694. if is_output:
  1695. if unsupported_output_tensor(meta, node):
  1696. return True
  1697. else:
  1698. if unsupported_input_tensor(meta, node):
  1699. return True
  1700. return False
  1701. # only skip codegen if there is a cpu output, not input
  1702. for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs):
  1703. if check_skip_condition(arg, is_output=False):
  1704. return True
  1705. return check_skip_condition(node, is_output=True)
  1706. def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
  1707. assert op not in decompositions or override_decomp, (
  1708. f"both a fallback and a decomp for same op: {op}"
  1709. )
  1710. if (
  1711. warn
  1712. and bool(os.getenv("CI"))
  1713. and get_decompositions([op])
  1714. # if fallback_random, we allow not decomposing random
  1715. and not (
  1716. config.fallback_random
  1717. and op in torch._decomp.decompositions_for_rng.extra_random_decomps
  1718. )
  1719. and not override_decomp
  1720. ):
  1721. # Note: 'warn' is holdover from when this was a warning, but for ops that previously
  1722. # set warn=False we do not want a CI error.
  1723. # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not
  1724. # likely to be triggered preferentially on one CI config over another.
  1725. if torch._dynamo.config.suppress_errors:
  1726. torch._dynamo.config.suppress_errors = False
  1727. log.warning(
  1728. "A make_fallback error occurred in suppress_errors config,"
  1729. " and suppress_errors is being disabled to surface it."
  1730. )
  1731. raise AssertionError(
  1732. f"make_fallback({op}): a decomposition exists, we should switch to it."
  1733. " To fix this error, either add a decomposition to core_aten_decompositions (preferred)"
  1734. " or inductor_decompositions, and delete the corresponding `make_fallback` line."
  1735. " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.",
  1736. )
  1737. def register_fallback(op_overload):
  1738. add_needs_realized_inputs(op_overload)
  1739. if layout_constraint is not None:
  1740. add_layout_constraint(op_overload, layout_constraint)
  1741. return register_lowering(op_overload, type_promotion_kind=None)(
  1742. fallback_handler(op_overload)
  1743. )
  1744. if isinstance(op, torch._ops.OpOverloadPacket):
  1745. for ol in op.overloads():
  1746. op_overload = getattr(op, ol)
  1747. register_fallback(op_overload)
  1748. elif isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
  1749. register_fallback(op)
  1750. else:
  1751. raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")
  1752. def philox_rand_offset(shape):
  1753. """
  1754. TorchInductor offset calculation differs from PyTorch eager offset
  1755. calculation for random ops (tl.rand vs torch.rand). In future, we should
  1756. strive for same impl for tl.rand and torch.rand.
  1757. """
  1758. numel = 1
  1759. for s in shape:
  1760. numel = numel * s
  1761. return tensor(numel, dtype=torch.int64)
  1762. @register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None)
  1763. def philox_rand(size, seed, offset, stride, device, dtype):
  1764. # stride arg is optional and will be used in future for distributed random
  1765. # ops. Currently, its unused.
  1766. random_pos = ir.FixedLayout(
  1767. device,
  1768. dtype,
  1769. size,
  1770. ir.FlexibleLayout.contiguous_strides(size),
  1771. ).make_indexer()
  1772. seed_loader = seed.make_loader()
  1773. offset_loader = offset.make_loader()
  1774. def inner_fn(index):
  1775. # Both seed and offset in the philox_rand op are tensors.
  1776. # torch seed and offsets are of type int64, but tl.rand accepts int32
  1777. seed_index_expr = ops.to_dtype(seed_loader([]), torch.int32)
  1778. offset_index_expr = ops.to_dtype(offset_loader([]), torch.int32)
  1779. # Get the offset'd position
  1780. rand_index_expr = ops.add(
  1781. ops.index_expr(random_pos(index), torch.int32), offset_index_expr
  1782. )
  1783. result = ops.rand(
  1784. seed_index_expr,
  1785. rand_index_expr,
  1786. )
  1787. return ops.to_dtype(result, dtype)
  1788. random_values_node = Pointwise.create(
  1789. device=device,
  1790. dtype=dtype,
  1791. inner_fn=inner_fn,
  1792. ranges=list(size),
  1793. )
  1794. offset_node = philox_rand_offset(size)
  1795. return random_values_node, offset_node
  1796. @register_lowering(aten.native_dropout, type_promotion_kind=None)
  1797. def native_dropout(x, p, train):
  1798. if config.fallback_random:
  1799. return pytree.tree_map(
  1800. TensorBox.create,
  1801. ir.FallbackKernel.create(aten.native_dropout.default, x, p, train),
  1802. )
  1803. else:
  1804. raise AssertionError("should be handled in replace_random.py")
  1805. @register_lowering(aten.bernoulli_, type_promotion_kind=None)
  1806. def bernoulli_(x, *args):
  1807. assert config.fallback_random or x.get_device() == torch.device("cpu"), (
  1808. "this should be handled in decomps unless config.fallback_random or the device is CPU"
  1809. )
  1810. x.realize()
  1811. op_overload = (
  1812. aten.bernoulli_.float
  1813. if len(args) == 0 or isinstance(args[0], float)
  1814. else aten.bernoulli_.Tensor
  1815. )
  1816. ir.InplaceBernoulliFallback(op_overload, x, *args)
  1817. return x
  1818. @register_lowering(aten.bernoulli.p, type_promotion_kind=None)
  1819. def bernoulli_p(x, *args):
  1820. assert config.fallback_random or x.get_device() == torch.device("cpu"), (
  1821. "this should be handled in decomps unless config.fallback_random or the device is CPU"
  1822. )
  1823. return bernoulli_(clone(x), *args)
  1824. # This shouldn't be called in general
  1825. @register_lowering(aten._foobar)
  1826. def _foobar(_):
  1827. raise AssertionError
  1828. @functools.lru_cache(1)
  1829. def _warn_triton_random(salt):
  1830. log.info("using triton random, expect difference from eager")
  1831. def warn_triton_random():
  1832. # only warn once per graph
  1833. _warn_triton_random(V.graph.creation_time)
  1834. fallback_rand_default = fallback_handler(aten.rand.default)
  1835. fallback_rand_generator = fallback_handler(aten.rand.generator)
  1836. fallback_randn_default = fallback_handler(aten.randn.default)
  1837. fallback_randn_generator = fallback_handler(aten.randn.generator)
  1838. make_fallback(aten.randint)
  1839. @register_lowering(aten.rand)
  1840. def rand(*args, **kwargs):
  1841. if kwargs.get("generator", None) is not None:
  1842. return fallback_rand_generator(*args, **kwargs)
  1843. elif config.fallback_random:
  1844. kwargs.pop("generator", None)
  1845. return fallback_rand_default(*args, **kwargs)
  1846. raise AssertionError("should have been handled in replace_random.py")
  1847. @register_lowering(aten.randn)
  1848. def randn(*args, **kwargs):
  1849. if kwargs.get("generator", None) is not None:
  1850. return fallback_randn_generator(*args, **kwargs)
  1851. elif config.fallback_random:
  1852. kwargs.pop("generator", None)
  1853. return fallback_randn_default(*args, **kwargs)
  1854. raise AssertionError("should have been handled in replace_random.py")
  1855. @register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None)
  1856. def inductor_force_stride_order(input_tensor, stride):
  1857. stride_order = ir.get_stride_order(stride)
  1858. return ir.ExternKernel.require_stride_order(input_tensor, stride_order)
  1859. @register_lowering(inductor_prims.seed, type_promotion_kind=None)
  1860. def inductor_seed(device: torch.device):
  1861. raise AssertionError("should be handled in fuse_seed_creation_pass()")
  1862. @register_lowering(inductor_prims.seeds, type_promotion_kind=None)
  1863. def inductor_seeds(count, device):
  1864. warn_triton_random()
  1865. return TensorBox.create(ir.RandomSeeds(count, decode_device(device)))
  1866. @register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None)
  1867. def inductor_lookup_seed(seeds, index):
  1868. def inner_fn(_):
  1869. return ops.load_seed(seeds.get_name(), index)
  1870. return Pointwise.create(
  1871. device=seeds.get_device(),
  1872. dtype=seeds.get_dtype(),
  1873. inner_fn=inner_fn,
  1874. ranges=[],
  1875. )
  1876. @register_lowering(inductor_prims.random, type_promotion_kind=None)
  1877. def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0):
  1878. assert not config.fallback_random
  1879. assert mode in ("rand", "randn")
  1880. size = [*size]
  1881. dtype = torch.float32
  1882. device = seed.get_device_or_error()
  1883. random_pos = ir.FixedLayout(
  1884. device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
  1885. ).make_indexer()
  1886. seed_loader = seed.make_loader()
  1887. def inner_fn(index):
  1888. return getattr(ops, mode)(
  1889. seed_loader([]),
  1890. ops.index_expr(random_pos(index), torch.int32),
  1891. )
  1892. result = Pointwise.create(
  1893. device=device,
  1894. dtype=dtype,
  1895. inner_fn=inner_fn,
  1896. ranges=[*size],
  1897. )
  1898. result.realize()
  1899. return result
  1900. @register_lowering(inductor_prims.randint, type_promotion_kind=None)
  1901. def inductor_randint(
  1902. low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0
  1903. ):
  1904. assert not config.fallback_random
  1905. size = [*size]
  1906. dtype = torch.int64
  1907. device = seed.get_device_or_error()
  1908. random_pos = ir.FixedLayout(
  1909. device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
  1910. ).make_indexer()
  1911. seed_loader = seed.make_loader()
  1912. def inner_fn(index):
  1913. return ops.randint64(
  1914. seed_loader([]),
  1915. ops.index_expr(random_pos(index), torch.int32),
  1916. ops.index_expr(low, torch.int64),
  1917. ops.index_expr(high, torch.int64),
  1918. )
  1919. return Pointwise.create(
  1920. device=device,
  1921. dtype=dtype,
  1922. inner_fn=inner_fn,
  1923. ranges=[*size],
  1924. )
  1925. def _boundaries_helper(tb: TensorBox) -> tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]:
  1926. return (
  1927. tb.get_name(),
  1928. tb.get_size()[-1],
  1929. tb.get_size()[0] * tb.get_stride()[0],
  1930. tb.get_stride()[-1],
  1931. )
  1932. def _sorter_helper(tb: TensorBox) -> tuple[str, sympy.Expr]:
  1933. return tb.get_name(), tb.get_stride()[-1]
  1934. @register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None)
  1935. def searchsorted(
  1936. sorted_sequence: TensorBox,
  1937. self: TensorBox,
  1938. *,
  1939. out_int32: bool = False,
  1940. right: bool = False,
  1941. side: Optional[str] = None,
  1942. sorter: Optional[TensorBox] = None,
  1943. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1944. validate_bucketize = lambda tb: V.graph.has_feature( # noqa: E731
  1945. tb, BackendFeature.BUCKETIZE
  1946. )
  1947. if (
  1948. not validate_bucketize(sorted_sequence)
  1949. or not validate_bucketize(self)
  1950. or (sorter is not None and not validate_bucketize(sorter))
  1951. ):
  1952. return fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)(
  1953. sorted_sequence,
  1954. self,
  1955. out_int32=out_int32,
  1956. right=right,
  1957. side=side,
  1958. sorter=sorter,
  1959. )
  1960. # If side is present, override the value of right if needed. This assumes that
  1961. # validation of the two options being non-contradictory is already done by the
  1962. # searchsorted meta-function.
  1963. if side is not None and side == "right":
  1964. right = True
  1965. index_dtype = torch.int32 if out_int32 else torch.int64
  1966. values_loader = self.make_loader()
  1967. # The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to
  1968. # realize it into global memory; or in other words, we can't guarantee that
  1969. # sorted_sequence.get_name() (used below) will exist unless we call
  1970. # sorted_sequence.realize().
  1971. sorted_sequence.realize()
  1972. if sorter is not None:
  1973. sorter.realize()
  1974. if len(sorted_sequence.get_size()) == 1:
  1975. def inner_fn(idx):
  1976. val = values_loader(idx)
  1977. return ops.bucketize(
  1978. val,
  1979. _boundaries_helper(sorted_sequence),
  1980. 0,
  1981. index_dtype,
  1982. right,
  1983. sorter=None if sorter is None else _sorter_helper(sorter),
  1984. sorter_indices=None if sorter is None else 0,
  1985. )
  1986. else:
  1987. def inner_fn(idx):
  1988. val = values_loader(idx)
  1989. # Get index to the beginning of the sorted sequence within a flattened
  1990. # version of the array.
  1991. def get_flattened_index(tb: TensorBox):
  1992. strides = tb.get_stride()
  1993. return ops.index_expr(
  1994. functools.reduce(
  1995. operator.add, (s * i for s, i in zip(strides[:-1], idx[:-1]))
  1996. ),
  1997. index_dtype,
  1998. )
  1999. return ops.bucketize(
  2000. val,
  2001. _boundaries_helper(sorted_sequence),
  2002. get_flattened_index(sorted_sequence),
  2003. index_dtype,
  2004. right,
  2005. sorter=None if sorter is None else _sorter_helper(sorter),
  2006. sorter_indices=None if sorter is None else get_flattened_index(sorter),
  2007. )
  2008. device = self.get_device()
  2009. result = Pointwise.create(
  2010. device=device,
  2011. dtype=index_dtype,
  2012. inner_fn=inner_fn,
  2013. ranges=self.shape,
  2014. )
  2015. # see [NOTE: inductor bucketize realize]
  2016. result.realize()
  2017. return result
  2018. @register_lowering(
  2019. aten.bucketize, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
  2020. )
  2021. def bucketize(
  2022. input: TensorBox,
  2023. boundaries: TensorBox,
  2024. *,
  2025. out_int32: bool = False,
  2026. right: bool = False,
  2027. ):
  2028. assert len(boundaries.get_size()) == 1
  2029. if not (
  2030. V.graph.has_feature(input, BackendFeature.BUCKETIZE)
  2031. and V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE)
  2032. ):
  2033. return fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)(
  2034. input, boundaries, out_int32=out_int32, right=right
  2035. )
  2036. # The entire boundaries tensor needs to be used by ops.bucketize, so we
  2037. # need to realize it into global memory; or in other words, we can't
  2038. # guarantee that boundaries.get_name() (used below) will exist unless
  2039. # we call boundaries.realize().
  2040. boundaries.realize()
  2041. device = input.get_device()
  2042. input_loader = input.make_loader()
  2043. index_dtype = torch.int32 if out_int32 else torch.int64
  2044. def inner_fn(index):
  2045. val = input_loader(index)
  2046. indices = ops.bucketize(
  2047. val,
  2048. _boundaries_helper(boundaries),
  2049. 0,
  2050. index_dtype,
  2051. right,
  2052. )
  2053. return indices
  2054. result = Pointwise.create(
  2055. device=device,
  2056. dtype=index_dtype,
  2057. inner_fn=inner_fn,
  2058. ranges=input.get_size(),
  2059. )
  2060. # [NOTE: inductor bucketize realize]
  2061. # bucketize_binary_search is relatively expensive, so we don't want to re-compute
  2062. # it unnecessarily. If we run bucketize() and then broadcast the result, we don't
  2063. # want this to be fused into a large number of duplicate bucketize() computations
  2064. # for each of the elements in the result.
  2065. #
  2066. # If no broadcasting occurs, fusions can still occur in scheduler.py
  2067. result.realize()
  2068. return result
  2069. def require_dense(_, *args, **kwargs):
  2070. args, kwargs = pytree.tree_map_only(
  2071. ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs)
  2072. )
  2073. return args, kwargs
  2074. def require_contiguous(_, *args, **kwargs):
  2075. args, kwargs = pytree.tree_map_only(
  2076. ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs)
  2077. )
  2078. return args, kwargs
  2079. def require_contiguous_strides(_, *args, **kwargs):
  2080. # TODO: combine this with require_contiguous after
  2081. # https://github.com/pytorch/pytorch/pull/148235 lands.
  2082. args, kwargs = pytree.tree_map_only(
  2083. ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs)
  2084. )
  2085. return args, kwargs
  2086. def require_channels_last(_, *args, **kwargs):
  2087. args, kwargs = pytree.tree_map_only(
  2088. ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
  2089. )
  2090. return args, kwargs
  2091. def constrain_to_fake_tensor(arg, fake_arg):
  2092. if isinstance(arg, ir.IRNode):
  2093. meta_stride_expr = [
  2094. s.node.expr if isinstance(s, torch.SymInt) else s for s in fake_arg.stride()
  2095. ]
  2096. return ir.ExternKernel.require_exact_strides(arg, meta_stride_expr)
  2097. if isinstance(arg, dict):
  2098. return {
  2099. key: constrain_to_fake_tensor(arg[key], fake_arg[key]) for key in arg.keys()
  2100. }
  2101. elif isinstance(arg, (tuple, list)):
  2102. return type(arg)(
  2103. constrain_to_fake_tensor(a, f_a) for (a, f_a) in zip(arg, fake_arg)
  2104. )
  2105. return arg
  2106. def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs):
  2107. args = tuple(
  2108. constrain_to_fake_tensor(arg, fake_arg)
  2109. for arg, fake_arg in zip(args, fake_args)
  2110. )
  2111. kwargs = {k: constrain_to_fake_tensor(v, fake_kwargs[k]) for k, v in kwargs.items()}
  2112. return args, kwargs
  2113. def constrain_to_fx_strides(fx_node, *args, **kwargs):
  2114. def apply_constraint(arg, fx_arg):
  2115. if isinstance(arg, ir.IRNode):
  2116. stride_order = ir.get_stride_order(
  2117. fx_arg.meta["val"].stride(), V.graph.sizevars.shape_env
  2118. )
  2119. return ir.ExternKernel.require_stride_order(arg, stride_order)
  2120. if isinstance(arg, dict):
  2121. return {key: apply_constraint(arg[key], fx_arg[key]) for key in arg.keys()}
  2122. return arg
  2123. args = tuple(
  2124. apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
  2125. )
  2126. kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
  2127. return args, kwargs
  2128. def sdpa_constraint(fx_node, *args, **kwargs):
  2129. # sdpa requires dense last dimension]
  2130. def apply_constraint(idx, arg, fx_arg):
  2131. if not isinstance(arg, ir.IRNode):
  2132. return arg
  2133. meta_val = fx_arg.meta["val"]
  2134. meta_stride_expr = [
  2135. s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_val.stride()
  2136. ]
  2137. stride_order = ir.get_stride_order(meta_val.stride())
  2138. if stride_order and stride_order[-1] != 0:
  2139. # contiguous stride order
  2140. stride_order = list(reversed(range(len(arg.get_size()))))
  2141. if (
  2142. fx_node.target
  2143. == aten._scaled_dot_product_efficient_attention_backward.default
  2144. and idx in (0, 5)
  2145. ):
  2146. assert len(stride_order) == 4
  2147. # The 0 and 5th arguments for aten._scaled_dot_product_efficient_attention_backward.default
  2148. # are for out and gradient_out. They have to be in
  2149. # (3, 1, 2, 0) stride order. Otherwise the kernel will crash.
  2150. # Check https://github.com/pytorch/pytorch/issues/138772
  2151. stride_order = (3, 1, 2, 0)
  2152. if not meta_val.is_cuda:
  2153. return ir.ExternKernel.require_stride_order(arg, stride_order)
  2154. # This is the minimum alignment required by SDPA kernels for attention_bias.
  2155. # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
  2156. ALIGNMENT = 8
  2157. # effn_attn_fwd does requires dense last dim, not just alignment
  2158. effn_attn_fwd_bias = (
  2159. fx_node.target
  2160. == torch.ops.aten._scaled_dot_product_efficient_attention.default
  2161. and idx == 3
  2162. )
  2163. assert isinstance(arg, TensorBox)
  2164. if len(arg.get_size()) not in (3, 4):
  2165. return arg
  2166. is_aligned_tensor = ir.is_aligned_realized_tensor_hint(arg, ALIGNMENT)
  2167. if is_aligned_tensor:
  2168. return ir.try_match_insignificant_strides(
  2169. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2170. )
  2171. if (
  2172. isinstance(arg, IRNode)
  2173. and arg.maybe_get_stride() is not None
  2174. and is_aligned_tensor
  2175. ):
  2176. return ir.try_match_insignificant_strides(
  2177. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2178. )
  2179. if effn_attn_fwd_bias:
  2180. out_size = list(arg.get_size())
  2181. expanded_dims = []
  2182. # We require a dense last dimension, but the other strides
  2183. # can be expanded, which results in a smaller tensor
  2184. maybe_stride = arg.maybe_get_stride()
  2185. for i in range(len(arg.get_size()) - 1):
  2186. if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or (
  2187. maybe_stride is not None
  2188. and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0)
  2189. ):
  2190. expanded_dims.append(i)
  2191. # Now, pad strides to alignment
  2192. out_strides = [-1] * len(out_size)
  2193. out_strides[-1] = 1
  2194. stride = 1
  2195. for i in range(len(out_size) - 2, -1, -1):
  2196. if out_strides[i + 1] != 0:
  2197. stride = stride * out_size[i + 1]
  2198. # the expanded dims still need to be aligned, if they are,
  2199. # we can make them expanded by setting the stride equal to 0
  2200. if i in expanded_dims:
  2201. if V.graph.sizevars.statically_known_equals(
  2202. out_strides[i + 1] % ALIGNMENT, 0
  2203. ):
  2204. out_strides[i] = 0
  2205. continue
  2206. if not V.graph.sizevars.statically_known_equals(stride % ALIGNMENT, 0):
  2207. stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT
  2208. out_strides[i] = stride
  2209. return ir.ExternKernel.require_exact_strides(arg, out_strides)
  2210. if is_aligned_tensor:
  2211. return ir.try_match_insignificant_strides(
  2212. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2213. )
  2214. if (
  2215. isinstance(arg, IRNode)
  2216. and arg.maybe_get_stride() is not None
  2217. and is_aligned_tensor
  2218. ):
  2219. return ir.try_match_insignificant_strides(
  2220. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2221. )
  2222. def is_aligned(x):
  2223. return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
  2224. if isinstance(arg.data, ir.BaseView):
  2225. if not is_aligned(arg):
  2226. if is_aligned(arg.unwrap_view()):
  2227. return ir.try_match_insignificant_strides(
  2228. ir.ExternKernel.realize_input(arg), meta_stride_expr
  2229. )
  2230. return ir.ExternKernel.require_stride_order(arg, stride_order)
  2231. args = tuple(
  2232. apply_constraint(idx, arg, fx_arg)
  2233. for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args))
  2234. )
  2235. kwargs = {k: apply_constraint(-1, v, fx_node.kwargs[k]) for k, v in kwargs.items()}
  2236. return args, kwargs
  2237. # WIP
  2238. make_fallback(aten._adaptive_avg_pool3d) # @isuruf
  2239. make_fallback(aten.adaptive_max_pool3d) # @isuruf
  2240. make_fallback(aten._scaled_dot_product_attention_math_for_mps) # @malfet
  2241. # 1) Easy
  2242. make_fallback(aten.uniform, warn=False)
  2243. make_fallback(aten.exponential.default, warn=False) # (fails accuracy on test_torch.py)
  2244. make_fallback(aten._pdist_forward) # Has decomp. Needs benchmarks
  2245. make_fallback(aten.soft_margin_loss_backward, warn=False) # py_impl?
  2246. make_fallback(aten._fused_rms_norm, warn=False) # (MPS-only and faster than decomp)
  2247. if torch.xpu.is_available():
  2248. make_fallback(
  2249. aten.embedding_dense_backward, warn=False
  2250. ) # (XPU-only and faster than decomp)
  2251. # 1.5) Easy or Impossible
  2252. make_fallback(aten._cdist_forward) # p=2 should be feasible
  2253. make_fallback(aten._cdist_backward)
  2254. # 2) Medium
  2255. make_fallback(aten._trilinear)
  2256. # 3) Difficult
  2257. # Scans
  2258. # See the discussion at
  2259. # https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19
  2260. make_fallback(aten.segment_reduce.default)
  2261. make_fallback(aten._segment_reduce_backward.default)
  2262. # Histogram (need to implement Histogram IR)
  2263. make_fallback(aten.histc)
  2264. make_fallback(aten.histogram.bin_ct)
  2265. make_fallback(aten._histogramdd_bin_edges.default)
  2266. make_fallback(aten._histogramdd_from_bin_cts.default)
  2267. # Need templated kernel
  2268. make_fallback(aten.addbmm)
  2269. make_fallback(aten._addmm_activation, warn=False)
  2270. make_fallback(aten._grouped_mm, require_dense)
  2271. # Need templated kernel. Probably impossible to write efficiently
  2272. make_fallback(aten.convolution_backward, constrain_to_fx_strides)
  2273. make_fallback(aten._cudnn_rnn, require_dense)
  2274. make_fallback(aten._cudnn_rnn_backward, require_contiguous)
  2275. # Haven't checked but sound difficult / impossible
  2276. make_fallback(aten._embedding_bag, require_contiguous)
  2277. make_fallback(aten._embedding_bag_forward_only, require_contiguous)
  2278. make_fallback(aten._embedding_bag_backward)
  2279. make_fallback(aten._embedding_bag_per_sample_weights_backward)
  2280. make_fallback(aten._embedding_bag_per_sample_weights_backward)
  2281. make_fallback(aten._fused_moving_avg_obs_fq_helper)
  2282. make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
  2283. # 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
  2284. make_fallback(aten.max_pool3d_with_indices_backward)
  2285. make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
  2286. make_fallback(aten._adaptive_avg_pool3d_backward)
  2287. make_fallback(aten.adaptive_max_pool2d_backward)
  2288. make_fallback(aten.adaptive_max_pool3d_backward)
  2289. make_fallback(aten.fractional_max_pool2d_backward)
  2290. make_fallback(aten.fractional_max_pool3d_backward)
  2291. make_fallback(aten.replication_pad1d_backward)
  2292. make_fallback(aten.replication_pad2d_backward)
  2293. make_fallback(aten.upsample_linear1d_backward)
  2294. make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
  2295. make_fallback(aten.upsample_trilinear3d_backward)
  2296. make_fallback(aten.grid_sampler_2d_backward, require_dense)
  2297. make_fallback(aten._pdist_backward)
  2298. # 5) Impossible (missing triton/CPU features)
  2299. # Sorting / Sorting-like
  2300. make_fallback(aten.sort)
  2301. make_fallback(aten.sort.stable)
  2302. make_fallback(aten.kthvalue)
  2303. make_fallback(aten.topk)
  2304. make_fallback(aten.mode)
  2305. make_fallback(aten.median)
  2306. make_fallback(aten.nanmedian)
  2307. make_fallback(aten.randperm)
  2308. # see: https://github.com/pytorch/pytorch/pull/121354
  2309. make_fallback(aten.resize_)
  2310. make_fallback(aten.resize_as_)
  2311. # Linalg
  2312. make_fallback(aten._linalg_det)
  2313. make_fallback(aten.linalg_householder_product)
  2314. make_fallback(aten.linalg_inv_ex)
  2315. make_fallback(aten.linalg_ldl_factor_ex)
  2316. make_fallback(aten.linalg_ldl_solve)
  2317. make_fallback(aten.linalg_lu)
  2318. make_fallback(aten.linalg_lu_factor_ex)
  2319. make_fallback(aten.linalg_lu_solve)
  2320. make_fallback(aten.linalg_matrix_exp)
  2321. make_fallback(aten.linalg_qr)
  2322. make_fallback(aten._linalg_slogdet)
  2323. make_fallback(aten._linalg_solve_ex)
  2324. make_fallback(aten.linalg_solve_triangular)
  2325. make_fallback(aten._linalg_svd)
  2326. make_fallback(aten.lu_unpack)
  2327. make_fallback(aten.ormqr)
  2328. make_fallback(aten._linalg_check_errors)
  2329. make_fallback(aten.linalg_pinv.atol_rtol_tensor)
  2330. make_fallback(aten._linalg_eigh)
  2331. make_fallback(aten.triangular_solve)
  2332. make_fallback(aten.linalg_cholesky_ex)
  2333. make_fallback(aten.cholesky_inverse)
  2334. make_fallback(aten.cholesky_solve)
  2335. make_fallback(aten.geqrf)
  2336. make_fallback(aten._fft_r2c) # needs complex as well
  2337. # Data dependent (are these necessary?)
  2338. make_fallback(aten.nonzero.default)
  2339. # Misc
  2340. make_fallback(aten.gcd.default, warn=False)
  2341. make_fallback(aten._thnn_fused_lstm_cell, require_dense)
  2342. make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
  2343. make_fallback(torch._prims.rng_prims.run_with_rng_state)
  2344. make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state)
  2345. # Implemented / Half implemented
  2346. # Scans. Implemented for CUDA, missing CPU
  2347. make_fallback(aten.masked_scatter)
  2348. make_fallback(aten.masked_scatter_backward)
  2349. # Complex number support
  2350. make_fallback(aten.view_as_complex, require_contiguous)
  2351. make_fallback(aten.angle) # needs complex
  2352. # Needs efficentzerotensor
  2353. make_fallback(aten._efficientzerotensor)
  2354. # Needs Sparse
  2355. make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
  2356. make_fallback(aten.to_sparse)
  2357. make_fallback(aten._to_sparse)
  2358. # Needs dimname support
  2359. make_fallback(aten.zeros.names)
  2360. # 6) Pattern-matched
  2361. make_fallback(
  2362. aten._scaled_dot_product_efficient_attention.default,
  2363. sdpa_constraint,
  2364. warn=False,
  2365. )
  2366. make_fallback(
  2367. aten._scaled_dot_product_efficient_attention_backward.default,
  2368. sdpa_constraint,
  2369. warn=False,
  2370. )
  2371. make_fallback(
  2372. aten._scaled_dot_product_flash_attention.default,
  2373. sdpa_constraint,
  2374. warn=False,
  2375. )
  2376. make_fallback(
  2377. aten._scaled_dot_product_flash_attention_backward.default,
  2378. sdpa_constraint,
  2379. warn=False,
  2380. )
  2381. make_fallback(
  2382. aten._scaled_dot_product_cudnn_attention.default,
  2383. sdpa_constraint,
  2384. warn=False,
  2385. )
  2386. make_fallback(
  2387. aten._scaled_dot_product_cudnn_attention_backward.default,
  2388. sdpa_constraint,
  2389. warn=False,
  2390. )
  2391. make_fallback(
  2392. aten._scaled_dot_product_flash_attention_for_cpu.default,
  2393. sdpa_constraint,
  2394. warn=False,
  2395. )
  2396. make_fallback(
  2397. aten._scaled_dot_product_flash_attention_for_cpu_backward.default,
  2398. sdpa_constraint,
  2399. warn=False,
  2400. )
  2401. make_fallback(
  2402. aten._scaled_dot_product_fused_attention_overrideable.default,
  2403. sdpa_constraint,
  2404. warn=False,
  2405. )
  2406. make_fallback(
  2407. aten._scaled_dot_product_fused_attention_overrideable_backward.default,
  2408. sdpa_constraint,
  2409. warn=False,
  2410. )
  2411. make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
  2412. make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
  2413. make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
  2414. make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)
  2415. # index_reduce requires fallback when use_scatter_fallback(...) returns True
  2416. make_fallback(aten.index_reduce)
  2417. make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
  2418. # Register with type_promotion_kind None.
  2419. # For example, fp16.copy_(fp32) should **not** promote the first input's dtype.
  2420. @register_lowering(aten.copy, type_promotion_kind=None)
  2421. def copy(self, src, non_blocking=False):
  2422. x = src
  2423. if self.get_device() != src.get_device():
  2424. x = to_device(x, self.get_device())
  2425. if self.get_dtype() != src.get_dtype():
  2426. x = to_dtype(x, self.get_dtype())
  2427. if self.get_size() != src.get_size():
  2428. out = expand(x, self.get_size())
  2429. return clone(out)
  2430. return clone(x)
  2431. @register_lowering(aten.clone)
  2432. def clone(x, *, memory_format=None):
  2433. # TODO(jansel): memory format
  2434. return Pointwise.create(
  2435. device=x.get_device(),
  2436. dtype=x.get_dtype(),
  2437. inner_fn=x.make_loader(),
  2438. ranges=list(x.get_size()),
  2439. )
  2440. def clone_preserve_reinterpret_view(x):
  2441. reinterpret_view_layouts = []
  2442. if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView):
  2443. x = x.data # unwrap TensorBox
  2444. while isinstance(x, ir.ReinterpretView):
  2445. reinterpret_view_layouts.append(x.get_layout())
  2446. x = x.data
  2447. x = TensorBox(x)
  2448. x = clone(x)
  2449. if reinterpret_view_layouts:
  2450. x = x.data # unwrap TensorBox
  2451. for layout in reinterpret_view_layouts[::-1]:
  2452. x = ir.ReinterpretView(data=x, layout=layout)
  2453. x = TensorBox(x)
  2454. return x
  2455. if hasattr(aten, "lift_fresh_copy"):
  2456. register_lowering(aten.lift_fresh_copy)(clone)
  2457. @register_lowering(prims.iota)
  2458. def iota(
  2459. length,
  2460. *,
  2461. start,
  2462. step,
  2463. dtype,
  2464. device,
  2465. requires_grad,
  2466. ):
  2467. def fn(index):
  2468. return ops.index_expr(step * index[0] + start, dtype=dtype)
  2469. return Pointwise.create(
  2470. device=decode_device(device),
  2471. dtype=dtype,
  2472. inner_fn=fn,
  2473. ranges=[length],
  2474. )
  2475. @register_lowering(aten.select_scatter, type_promotion_kind=None)
  2476. def select_scatter(x, src, dim: int, index: int):
  2477. assert x.get_dtype() == src.get_dtype()
  2478. x_loader = x.make_loader()
  2479. dim = _validate_dim(x, dim, 0)
  2480. if V.graph.sizevars.evaluate_expr(sympy.Lt(index, 0)):
  2481. index = index + x.get_size()[dim]
  2482. V.graph.sizevars.check_leq(0, index) # type: ignore[arg-type]
  2483. V.graph.sizevars.check_lt(index, x.get_size()[dim]) # type: ignore[arg-type]
  2484. src = expand(unsqueeze(src, dim), x.get_size())
  2485. src_loader = src.make_loader()
  2486. def inner_fn(idx):
  2487. return ops.where(
  2488. ops.eq(
  2489. ops.index_expr(idx[dim], torch.int32),
  2490. ops.index_expr(index, torch.int32),
  2491. ),
  2492. src_loader(idx),
  2493. x_loader(idx),
  2494. )
  2495. return Pointwise.create(
  2496. device=x.get_device(),
  2497. dtype=x.get_dtype(),
  2498. inner_fn=inner_fn,
  2499. ranges=list(x.get_size()),
  2500. )
  2501. @register_lowering(aten.slice_scatter, type_promotion_kind=None)
  2502. def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
  2503. src = to_dtype(src, x.get_dtype())
  2504. x_loader = x.make_loader()
  2505. dim = _validate_dim(x, dim, 0)
  2506. dim_size = x.get_size()[dim]
  2507. start, end = ir.SliceView.normalize_start_end(x, dim, start, end)
  2508. src_size = list(x.get_size())
  2509. src_size[dim] = FloorDiv(end - start + (step - 1), step)
  2510. src = expand(src, src_size)
  2511. src_loader = src.make_loader()
  2512. def inner_fn(idx):
  2513. if start == 0 and end == dim_size and step == 1:
  2514. # selecting every element is the same as just src.clone()
  2515. return src_loader(idx)
  2516. idx_dim = ops.index_expr(idx[dim], torch.int64)
  2517. src_idx = list(idx)
  2518. src_idx[dim] = FloorDiv(idx[dim] - start, step)
  2519. mask = []
  2520. if start != 0:
  2521. mask.append(
  2522. ops.ge(
  2523. idx_dim,
  2524. ops.index_expr(sympy.expand(start), torch.int64),
  2525. )
  2526. )
  2527. if end != dim_size:
  2528. mask.append(
  2529. ops.lt(
  2530. idx_dim,
  2531. ops.index_expr(sympy.expand(end), torch.int64),
  2532. )
  2533. )
  2534. if step != 1:
  2535. mask.append(
  2536. ops.eq(
  2537. ops.index_expr(
  2538. ModularIndexing(idx[dim] - start, 1, step), torch.int64
  2539. ),
  2540. ops.constant(0, torch.int64),
  2541. )
  2542. )
  2543. assert mask
  2544. mask = functools.reduce(ops.and_, mask)
  2545. src_val = ops.masked(
  2546. mask,
  2547. lambda: src_loader(src_idx),
  2548. 0 if is_integer_type(x) else 0.0,
  2549. )
  2550. return ops.where(
  2551. mask,
  2552. src_val,
  2553. x_loader(idx),
  2554. )
  2555. return Pointwise.create(
  2556. device=x.get_device(),
  2557. dtype=x.get_dtype(),
  2558. inner_fn=inner_fn,
  2559. ranges=list(x.get_size()),
  2560. )
  2561. def _unwrap(x):
  2562. if isinstance(x, (list, tuple)) and len(x) > 0:
  2563. return _unwrap(x[0])
  2564. return x
  2565. @register_lowering([torch.tensor, aten.scalar_tensor])
  2566. def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
  2567. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  2568. assert_nyi(not pin_memory, "pin_memory")
  2569. if isinstance(_unwrap(data), int):
  2570. dtype = dtype or torch.int64
  2571. else:
  2572. dtype = dtype or torch.get_default_dtype()
  2573. ranges: list[sympy.Expr] = []
  2574. if isinstance(data, sympy.Basic):
  2575. def inner_fn(index):
  2576. return ops.index_expr(data, dtype)
  2577. elif isinstance(data, (float, int)):
  2578. def inner_fn(index):
  2579. return ops.constant(data, dtype)
  2580. elif len(data) == 0 or isinstance(data[0], (float, int)) and len(data) <= 8:
  2581. # inline small tensors
  2582. ranges.append(sympy.Integer(len(data)))
  2583. def inner_fn(index):
  2584. def binary_search(start, end):
  2585. assert start < end
  2586. if end - start == 1:
  2587. return ops.constant(data[start], dtype)
  2588. mid = (end - start) // 2 + start
  2589. return ops.where(
  2590. ops.lt(
  2591. ops.index_expr(index[0], torch.int64),
  2592. ops.constant(mid, torch.int64),
  2593. ),
  2594. binary_search(start, mid),
  2595. binary_search(mid, end),
  2596. )
  2597. if len(data) == 0:
  2598. return ops.constant(0, dtype)
  2599. return binary_search(0, len(data))
  2600. else:
  2601. return V.graph.add_tensor_constant(
  2602. torch.tensor(data, dtype=dtype, device=device)
  2603. )
  2604. return Pointwise.create(
  2605. device=decode_device(device),
  2606. dtype=dtype,
  2607. inner_fn=inner_fn,
  2608. ranges=ranges,
  2609. )
  2610. @register_lowering(torch.as_tensor)
  2611. def as_tensor(data, dtype=None, device=None):
  2612. if isinstance(data, TensorBox):
  2613. if dtype is not None:
  2614. data = to_dtype(data, dtype)
  2615. if device is not None:
  2616. data = to_device(data, device)
  2617. return data
  2618. return tensor(data, dtype=dtype, device=device)
  2619. @register_lowering(torch.LongTensor)
  2620. def long_tensor(data):
  2621. return tensor(data, dtype=torch.int64)
  2622. @register_lowering(aten._local_scalar_dense)
  2623. def _local_scalar_dense(data):
  2624. # This is interesting! Most lowerings return tensors, so you can just
  2625. # return the buffer you allocated and it will get used (or not used, if
  2626. # it's dead.) But _local_scalar_dense (aka item) returns an int,
  2627. # not a Tensor, so you would have a type mismatch if you return a buffer;
  2628. # we are obligated to return a sympy expression instead. However,
  2629. # we need to actually codegen the .item() call somehow. We do this
  2630. # by registering a faux buffer for the DynamicScalar IR node, which is
  2631. # solely responsible for generating this .item(). The buffer is
  2632. # not used for anything (notice we discard it); at codegen time,
  2633. # the "buffer" just gets assigned None.
  2634. unbacked_bindings = resolve_unbacked_bindings(
  2635. V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
  2636. )
  2637. assert unbacked_bindings is not None
  2638. assert len(unbacked_bindings) == 1, unbacked_bindings
  2639. # NB: Have to be very careful here. V.graph.current_node.meta["val"]
  2640. # seemingly also contains a symbol which you want to do binding for,
  2641. # but it actually isn't. In particular, if we have later performed
  2642. # a deferred runtime assert saying that u0 == s0, you will actually
  2643. # see s0 from expr! This is bad because we need to actually generate
  2644. # the assert that says u0 == s0, so we need to know where to get u0
  2645. # from (this call). In particular, we must use unbacked_bindings, which
  2646. # is guaranteed to have the original, unreplaced symbol in question.
  2647. #
  2648. # NB2: Another thing we have to be very careful about are symbol bindings
  2649. # that require nontrivial refinement, e.g., when you have a binding site
  2650. # x: Sym(u0 * 4) = y.item(). Here, the code generation must do a division
  2651. # in order to appropriately bind u0. This is communicated via the keypath
  2652. # in unbacked_bindings, and we need to hold onto it in order to generate
  2653. # code appropriately for this case.
  2654. binding_sym, keypath = next(iter(unbacked_bindings.items()))
  2655. buffer = ir.DynamicScalar(binding_sym, keypath, data)
  2656. buffer.name = V.graph.register_buffer(buffer)
  2657. V.graph.register_operation(buffer)
  2658. # NB: the replaced expr is OK to use directly downstream, we want
  2659. # simplifications in this case!
  2660. val = V.graph.current_node.meta["val"]
  2661. if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
  2662. return val.node.expr
  2663. else:
  2664. return sympy.sympify(val)
  2665. @register_lowering(aten._assert_scalar)
  2666. def _assert_scalar(data, msg):
  2667. # NB: These will be handled at codegen time
  2668. # Not sure if we are guaranteed to be able to serve out truth from the
  2669. # deferred_runtime_asserts, TODO: try this assert out
  2670. # See [NOTE] Codegen runtime asserts in Inductor
  2671. # assert bool(data.scalar), data
  2672. return None
  2673. @register_lowering(aten._assert_tensor_metadata)
  2674. def _assert_tensor_metadata(
  2675. a, size=None, stride=None, dtype=None, *, device=None, layout=None
  2676. ):
  2677. return None
  2678. def _full(fill_value, device, dtype, size):
  2679. value = fill_value
  2680. if not isinstance(fill_value, (int, float)) and hasattr(value, "value"):
  2681. value = value.value
  2682. if isinstance(value, (int, float)):
  2683. def inner_fn(index):
  2684. return ops.constant(value, dtype)
  2685. elif isinstance(value, sympy.Basic):
  2686. def inner_fn(index):
  2687. return ops.index_expr(value, dtype)
  2688. else:
  2689. assert len(value.get_size()) == 0
  2690. value_loader = value.make_loader()
  2691. def inner_fn(index):
  2692. return value_loader([])
  2693. return Pointwise.create(
  2694. device=device,
  2695. dtype=dtype,
  2696. inner_fn=inner_fn,
  2697. ranges=list(size),
  2698. )
  2699. def full_like(x, fill_value, **kwargs):
  2700. return create_tensor_like(tensor_constructor(fill_value))(x, **kwargs)
  2701. def tensor_constructor(fill_value):
  2702. # torch.zeros, torch.ones, etc
  2703. def inner(
  2704. *size,
  2705. names=None,
  2706. dtype=None,
  2707. device=None,
  2708. layout=None,
  2709. pin_memory=False,
  2710. memory_format=None,
  2711. ):
  2712. assert_nyi(names is None, "named tensors")
  2713. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  2714. assert_nyi(not pin_memory, "pin_memory")
  2715. device = decode_device(device)
  2716. dtype = dtype or torch.get_default_dtype()
  2717. if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
  2718. size = tuple(size[0])
  2719. # See https://github.com/pytorch/pytorch/issues/118102
  2720. # All sizes at lowering time should be sympy.Symbol, not SymInt!
  2721. for s in size:
  2722. assert not isinstance(s, torch.SymInt)
  2723. size = [sympy.expand(s) for s in size]
  2724. return _full(fill_value, device, dtype, size)
  2725. return inner
  2726. @register_lowering([torch.empty, aten.empty])
  2727. def empty(
  2728. *size,
  2729. names=None,
  2730. dtype=None,
  2731. layout=None,
  2732. device=None,
  2733. pin_memory=None,
  2734. memory_format=None,
  2735. ):
  2736. assert_nyi(names is None, "named tensors")
  2737. device = decode_device(device)
  2738. if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
  2739. size = tuple(size[0])
  2740. return empty_strided(
  2741. size, None, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  2742. )
  2743. def create_tensor_like(creation_fn):
  2744. """
  2745. Shim to convert X_like(...) into X(...). For example zeros_like() into zeros().
  2746. """
  2747. def _constant_like(
  2748. x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None
  2749. ):
  2750. assert_nyi(not pin_memory, "pin_memory")
  2751. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  2752. if dtype is None:
  2753. dtype = x.get_dtype()
  2754. else:
  2755. dtype = decode_dtype(dtype)
  2756. device = device or x.get_device()
  2757. size = list(x.get_size())
  2758. return creation_fn(
  2759. size, dtype=dtype, device=device, layout=layout, pin_memory=pin_memory
  2760. )
  2761. return _constant_like
  2762. def constant_like(fill_value):
  2763. return create_tensor_like(tensor_constructor(fill_value))
  2764. empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty))
  2765. ones_like = create_tensor_like(tensor_constructor(1))
  2766. zeros_like = create_tensor_like(tensor_constructor(0))
  2767. def new_constant(fill_value):
  2768. def _new_constant(
  2769. x, size, *, dtype=None, layout=None, device=None, pin_memory=None
  2770. ):
  2771. assert isinstance(size, (list, tuple))
  2772. assert_nyi(not pin_memory, "pin_memory")
  2773. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  2774. dtype = decode_dtype(dtype) or x.get_dtype()
  2775. device = device or x.get_device()
  2776. size = [sympy.Integer(s) for s in size]
  2777. return _full(fill_value, decode_device(device), dtype, size)
  2778. return _new_constant
  2779. @register_lowering(aten.new_empty)
  2780. def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None):
  2781. if dtype is None:
  2782. dtype = x.get_dtype()
  2783. if device is None:
  2784. device = x.get_device()
  2785. return empty_strided(
  2786. size,
  2787. None,
  2788. dtype=dtype,
  2789. layout=layout,
  2790. device=decode_device(device),
  2791. pin_memory=pin_memory,
  2792. )
  2793. @register_lowering(aten.empty_strided)
  2794. def empty_strided(
  2795. size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
  2796. ):
  2797. assert isinstance(size, (list, tuple))
  2798. assert isinstance(stride, (list, tuple, type(None)))
  2799. assert_nyi(not pin_memory, "pin_memory")
  2800. assert_nyi(layout in (None, torch.strided), f"layout={layout}")
  2801. dtype = decode_dtype(dtype) or torch.get_default_dtype()
  2802. device = device or torch.tensor(0.0).device
  2803. device = decode_device(device)
  2804. pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
  2805. pointwise.realize()
  2806. buffer = pointwise.data.data
  2807. # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode
  2808. buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size))
  2809. assert isinstance(buffer, ir.ComputedBuffer)
  2810. size = [sympy.expand(s) for s in size]
  2811. stride = (
  2812. [sympy.expand(s) for s in stride]
  2813. if stride
  2814. else ir.FlexibleLayout.contiguous_strides(size)
  2815. )
  2816. buffer.layout = ir.FixedLayout(
  2817. device=device,
  2818. dtype=dtype,
  2819. size=size,
  2820. stride=stride,
  2821. )
  2822. return pointwise
  2823. @register_lowering(aten.new_empty_strided)
  2824. def new_empty_strided(
  2825. x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
  2826. ):
  2827. if dtype is None:
  2828. dtype = x.get_dtype()
  2829. if device is None:
  2830. device = x.get_device()
  2831. return empty_strided(
  2832. size,
  2833. stride,
  2834. dtype=dtype,
  2835. layout=layout,
  2836. device=decode_device(device),
  2837. pin_memory=pin_memory,
  2838. )
  2839. @register_lowering(prims.copy_strided.default)
  2840. def copy_strided(x, stride):
  2841. stride = [V.graph.sizevars.size_hint_or_throw(s) for s in stride]
  2842. stride_order = sorted(range(len(stride)), key=stride.__getitem__)
  2843. return ir.ExternKernel.require_stride_order(x, stride_order)
  2844. @register_lowering([torch.full, aten.full])
  2845. def full(size, fill_value, **kwargs):
  2846. assert kwargs.get("dtype") is not None, "dtype should be handled by decomposition"
  2847. return tensor_constructor(fill_value)(size, **kwargs)
  2848. @register_lowering(aten.gather, type_promotion_kind=None)
  2849. def gather(x, dim, index, sparse_grad=False):
  2850. # sparse_grad doesn't affect forward computation,
  2851. # and backward tracing is taken care of by AOT Autograd
  2852. assert isinstance(x, TensorBox)
  2853. if index.get_numel() == 0:
  2854. # Empty index case. Return an empty array with the same shape
  2855. return new_empty(x, index.get_size())
  2856. size = x.get_size()
  2857. offset = len(size) == 0
  2858. dim = _validate_dim(x, dim, offset)
  2859. if offset:
  2860. x = expand(x, [1])
  2861. size = [1]
  2862. x_loader = x.make_loader()
  2863. index_loader = index.make_loader()
  2864. def fn(idx):
  2865. idx = list(idx)
  2866. gather_idx = ops.indirect_indexing(index_loader(idx), size[dim])
  2867. if len(idx) == 0:
  2868. idx = [gather_idx]
  2869. else:
  2870. idx[dim] = gather_idx
  2871. return x_loader(idx)
  2872. return Pointwise.create(
  2873. device=x.get_device(),
  2874. dtype=x.get_dtype(),
  2875. inner_fn=fn,
  2876. ranges=index.get_size(),
  2877. )
  2878. @register_lowering(aten.embedding, type_promotion_kind=None)
  2879. def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
  2880. if sparse:
  2881. return fallback_handler(aten.embedding.default)(
  2882. weight, indices, padding_idx, scale_grad_by_freq, sparse
  2883. )
  2884. assert not sparse
  2885. assert isinstance(weight, TensorBox)
  2886. assert isinstance(indices, TensorBox)
  2887. assert "int" in str(indices.get_dtype())
  2888. weight_loader = weight.make_loader()
  2889. indices_loader = indices.make_loader()
  2890. indices_ndim = len(indices.get_size())
  2891. weight_size = weight.get_size()
  2892. new_size = [*indices.get_size(), *weight_size[1:]]
  2893. def fn(idx):
  2894. assert len(idx) == len(new_size), f"{idx} != {new_size}"
  2895. var_index = indices_loader(idx[:indices_ndim])
  2896. weight_idx = [ops.indirect_indexing(var_index, weight_size[0])] + [
  2897. *idx[indices_ndim:]
  2898. ]
  2899. return weight_loader(weight_idx)
  2900. return Pointwise.create(
  2901. device=weight.get_device(),
  2902. dtype=weight.get_dtype(),
  2903. inner_fn=fn,
  2904. ranges=new_size,
  2905. )
  2906. def check_and_broadcast_indices(indices, device):
  2907. assert all(
  2908. i.get_dtype() in (torch.int64, torch.int32, torch.bool, torch.uint8)
  2909. for i in indices
  2910. if i is not None
  2911. ), (
  2912. f"indices must be int64, byte or bool. Got {[i.get_dtype() for i in indices if i is not None]}"
  2913. )
  2914. if any(
  2915. i.get_dtype() in (torch.bool, torch.uint8) for i in indices if i is not None
  2916. ):
  2917. raise NotImplementedError("Fallback for bool indices")
  2918. valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)]
  2919. assert len(valid_idxs) > 0, "requires at least 1 non-None index"
  2920. new_indices = [None] * len(indices)
  2921. for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])):
  2922. # Eager allows indices to be CPU tensor when running on CUDA
  2923. # FIXME: Calling to_device(x, device) should work but
  2924. # test_advancedindex_mixed_cpu_devices still fails
  2925. if x.get_device() != device:
  2926. raise NotImplementedError("Fallback when indices is on a different device")
  2927. new_indices[i] = x
  2928. return new_indices, valid_idxs
def index_output_size_and_inner_fn(
    x_size,
    indices,
    tensor_indices,
    tensor_size,
    indices_loaders,
    indexed_size,
    x_loader,
    check,
    wrap_neg=True,
):
    """Compute the output shape and index mapping for advanced indexing.

    Args:
        x_size: sizes of the tensor being indexed.
        indices: per-dimension entries; None means the dim is kept as-is
            (a ``:`` slice), otherwise the dim is indexed by a tensor.
        tensor_indices: positions in ``indices`` that hold index tensors
            (assumed ascending, as produced by check_and_broadcast_indices).
        tensor_size: common (broadcast) shape of the index tensors.
        indices_loaders: per-position loader for the index tensor, or None.
        indexed_size: per-position bound passed to ``ops.indirect_indexing``.
        x_loader: if None, the returned fn yields the computed index list;
            otherwise it yields ``x_loader(index)``.
        check, wrap_neg: forwarded to ``ops.indirect_indexing``.

    Returns:
        ``(output_size, fn)``, where ``fn`` maps an output index to an index
        into ``x`` (or to the loaded value when ``x_loader`` is given).
    """
    # Note that behavior of indexing differs when there are non consecutive
    # tensors. In this case, the tensor index is pulled to the beginning.
    #
    # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7)
    #         x = torch.tensor([1, 2])
    # Then a[:, x, :, x, :] has shape (2, 3, 5, 7): because the tensor
    # indices x, :, x are not adjacent, their broadcast dim (2) is pulled to
    # the front.
    non_consecutive_tensors = False
    for previous, current in zip(tensor_indices, tensor_indices[1:]):
        if current - previous != 1:
            non_consecutive_tensors = True

    # Sizes of the dims that are not tensor-indexed, plus any trailing dims of
    # x beyond len(indices).
    output_size = [x_size[i] for i, val in enumerate(indices) if val is None]
    output_size = [*output_size, *x_size[len(output_size) + len(tensor_indices) :]]

    first_tensor_index = tensor_indices[0]
    if non_consecutive_tensors:
        output_size = tensor_size + output_size
    else:
        # Adjacent tensor indices: the broadcast dims replace them in place.
        output_size = (
            output_size[:first_tensor_index]
            + tensor_size
            + output_size[first_tensor_index:]
        )

    def fn(idx):
        assert len(idx) == len(output_size)
        assert len(indices_loaders) == len(indexed_size)

        rank = len(tensor_size)
        new_index = []
        first_tensor_index = tensor_indices[0]
        # Output dims [start_offset, start_offset + rank) are the broadcast
        # index-tensor dims; every index tensor is loaded at that slice.
        start_offset = 0 if non_consecutive_tensors else first_tensor_index
        # next_idx walks the output dims that map 1:1 onto None (slice) dims.
        next_idx = 0
        for i in range(tensor_indices[-1] + 1):
            if i == start_offset:
                # Skip over the broadcast index-tensor dims in the output.
                next_idx += rank
            if indices[i] is None:
                assert next_idx < len(idx)
                new_index.append(idx[next_idx])
                next_idx += 1
            else:
                loader = indices_loaders[i]
                assert loader is not None
                size = indexed_size[i]
                new_index.append(
                    ops.indirect_indexing(
                        loader(idx[start_offset : start_offset + rank]),
                        size,
                        check=check,
                        wrap_neg=wrap_neg,
                    )
                )
        # Dims after the last tensor index pass through unchanged.
        new_index = [
            *new_index,
            *idx[next_idx:],
        ]
        return new_index if x_loader is None else x_loader(new_index)

    return output_size, fn
  2995. def index_impl(x, indices, check):
  2996. output_size, inner_fn, _ = index_impl_helper(x, indices, check)
  2997. return Pointwise.create(
  2998. device=x.get_device(),
  2999. dtype=x.get_dtype(),
  3000. inner_fn=inner_fn,
  3001. ranges=output_size,
  3002. )
  3003. def index_impl_helper(x, indices, check, wrap_neg=True):
  3004. assert isinstance(indices, (list, tuple))
  3005. x_loader = x.make_loader()
  3006. indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device())
  3007. assert len(tensor_indices) > 0, "Must have at least one valid idx"
  3008. indices_loaders = [i.make_loader() if i is not None else None for i in indices]
  3009. # no guards on output size, all the guards are set in broadcast_tensors
  3010. # We can use the first one since they are all required to be the same size
  3011. tensor_size = list(indices[tensor_indices[0]].get_size())
  3012. x_size = x.get_size()
  3013. indexed_size = [x_size[i] for i in range(len(indices)) if indices[i] is not None]
  3014. if check and 0 in indexed_size and 0 not in tensor_size:
  3015. raise IndexError("index is out of bounds for dimension with size 0")
  3016. indexed_size = [x_size[i] for i in range(len(indices))]
  3017. output_size, index_inner_fn = index_output_size_and_inner_fn(
  3018. x_size,
  3019. indices,
  3020. tensor_indices,
  3021. tensor_size,
  3022. indices_loaders,
  3023. indexed_size,
  3024. None,
  3025. check=check,
  3026. wrap_neg=wrap_neg,
  3027. )
  3028. def inner_fn(idx):
  3029. return x_loader(index_inner_fn(idx))
  3030. return output_size, inner_fn, index_inner_fn
  3031. @register_lowering(aten.index, type_promotion_kind=None)
  3032. def index(x, indices):
  3033. try:
  3034. return index_impl(x, indices, check=True)
  3035. except NotImplementedError:
  3036. # Fallback to ATen for boolean indexing
  3037. x.realize()
  3038. return fallback_handler(aten.index.Tensor, add_to_fallback_set=False)(
  3039. x, indices
  3040. )
  3041. @register_lowering(aten._unsafe_index, type_promotion_kind=None)
  3042. def _unsafe_index(x, indices):
  3043. return index_impl(x, indices, check=False)
# All the indexing decompositions are written in terms of index, index_put, and index_put_.
# We cannot have this lowering as a decomposition as it introduces
# mutation in the graph, which is bad for AOT Autograd. AOT Autograd runs dead
# code elimination and common subexpression elimination optimizations, which
# assume graphs to be side-effect free. More details at
# https://github.com/pytorch/torchdynamo/issues/1235
# and
# https://github.com/pytorch/torchdynamo/issues/1863
  3052. @register_lowering(aten.index_put, type_promotion_kind=None)
  3053. def index_put(x, indices, values, accumulate=False):
  3054. return index_put_impl_(
  3055. clone(x), indices, values, accumulate, check=True, may_realize=False
  3056. )
  3057. @register_lowering(aten._unsafe_index_put)
  3058. def _unsafe_index_put(x, indices, values, accumulate=False):
  3059. return index_put_impl_(
  3060. clone(x), indices, values, accumulate, check=False, may_realize=False
  3061. )
  3062. def index_put_as_masked_fill(self, indices, value, accumulate):
  3063. if value.get_device() != self.get_device():
  3064. value = to_device(value, self.get_device())
  3065. if accumulate:
  3066. value = add(self, value)
  3067. return mutate_to(self, where(indices[0], value, self))
  3068. def index_put_fallback(self, indices, values, accumulate):
  3069. assert isinstance(V.graph.current_node.target, torch._ops.OpOverload)
  3070. ir.IndexPutFallback(V.graph.current_node.target, self, indices, values, accumulate)
  3071. return self
  3072. @register_lowering(aten.index_put_, type_promotion_kind=None)
  3073. def index_put_(self, indices, values, accumulate=False):
  3074. return index_put_impl_(
  3075. self, indices, values, accumulate, check=True, may_realize=True
  3076. )
  3077. @register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None)
  3078. def _unsafe_index_put_(self, indices, values, accumulate=False):
  3079. return index_put_impl_(
  3080. self, indices, values, accumulate, check=False, may_realize=True
  3081. )
def index_put_impl_(self, indices, values, accumulate, check, may_realize=False):
    """Scatter ``values`` into ``self`` at ``indices`` in place.

    Strategy, in order:
      1. single bool/uint8 index with a single value -> masked fill;
      2. deterministic algorithms enabled, or any bool/uint8 index -> ATen
         IndexPutFallback;
      3. accumulate on a dtype needing the atomic-add fallback -> ATen fallback;
      4. otherwise an ``ir.Scatter`` into a buffer whose layout mutates ``self``.

    Args:
        self: TensorBox to mutate.
        indices: per-dimension index tensors or None.
        values: values to write (cast to ``self``'s dtype and broadcast).
        accumulate: add into ``self`` (scatter_mode "atomic_add") instead of
            overwriting.
        check: forwarded to the indirect-indexing bounds handling.
        may_realize: when True, realize ``values`` if it reads from ``self``
            (unless every index comes from randperm) to avoid reading
            already-overwritten data.

    Returns:
        ``self`` (viewed back to 0-d when it was a scalar).
    """
    if may_realize:

        def try_get_name(x):
            # Unwrap TensorBox / views / StorageBox to find a named Buffer.
            if isinstance(x, ir.TensorBox):
                x = x.data
            if isinstance(x, ir.BaseView):
                x = x.unwrap_view()
            if isinstance(x, ir.StorageBox):
                x = x.data
            return x.get_name() if isinstance(x, ir.Buffer) else None

        def indice_slice_from_randperm(indice):
            # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660
            # For this specific pattern, indices is unique as coming from torch.randperm.
            # However, as the content of the indices is unknown, we have to check this specific pattern.
            if isinstance(indice, TensorBox) and isinstance(indice.data, ir.BaseView):
                indice = indice.data.unwrap_view()
                return (
                    isinstance(indice, ir.StorageBox)
                    and isinstance(indice.data, ir.ExternKernel)
                    and getattr(indice.data, "fx_node", None)
                    and indice.data.fx_node.target == torch.ops.aten.randperm.default
                )
            return False

        if try_get_name(self) in values.get_read_names() and not all(
            indice_slice_from_randperm(indice) for indice in indices
        ):
            # Fix issue: https://github.com/pytorch/pytorch/issues/138908
            # When self and values have memory overlapping, indices may
            # contain duplicate values, potentially causing incorrect results since
            # the load of `values` might contain modified value from the store of `self`.
            # To address this, store values in a temporary buffer in such cases.
            values.realize()

    # Dispatch to masked fill for single boolean index with single value
    if (
        values.get_numel() == 1
        and len(indices) == 1
        and indices[0].get_dtype() in (torch.bool, torch.uint8)
    ):
        mask = indices[0]
        # Append trailing size-1 dims so the mask broadcasts against self.
        for _ in range(len(mask.get_size()), len(self.get_size())):
            mask = unsqueeze(mask, -1)
        return index_put_as_masked_fill(self, [mask], values, accumulate)

    # Fallback in torch deterministic mode
    if torch.are_deterministic_algorithms_enabled():
        return index_put_fallback(self, indices, values, accumulate)

    # Fallback if there is a boolean index
    for index in indices:
        if index is not None and index.get_dtype() in (torch.bool, torch.uint8):
            return index_put_fallback(self, indices, values, accumulate)

    x_size = self.get_size()
    x_ndim = len(x_size)

    if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()):
        # self is an scalar Tensor
        if x_ndim == 0:
            self = view(self, [1])
        self = index_put_fallback(self, indices, values, accumulate)
        if x_ndim == 0:
            self = view(self, [])
        return self

    values = to_dtype(values, self.get_dtype())

    try:
        # bool/uint8 indices already returned above, so a NotImplementedError
        # here comes from check_and_broadcast_indices' cross-device check.
        indices, tensor_indices = check_and_broadcast_indices(
            indices, self.get_device()
        )
    except NotImplementedError:
        return index_put_fallback(self, indices, values, accumulate)

    indices_loaders = [i.make_loader() if i is not None else None for i in indices]

    assert isinstance(self, TensorBox)
    self.realize()

    # self is an scalar Tensor
    if x_ndim == 0:
        self = view(self, [1])

    # We can use the first one since they are all required to be the same size
    tensor_size = list(indices[tensor_indices[0]].get_size())
    # NOTE(review): x_size was captured before the 0-d -> [1] view above, so for
    # a 0-d self this comprehension indexes an empty size list (indices is
    # non-empty here) — confirm this path is unreachable for 0-d inputs.
    indexed_size = [x_size[i] for i in range(len(indices))]

    expected_vals_size, inner_fn = index_output_size_and_inner_fn(
        x_size,
        indices,
        tensor_indices,
        tensor_size,
        indices_loaders,
        indexed_size,
        None,
        check=check,
    )

    values = expand(values, expected_vals_size)
    # all guards are set above during broadcast_tensors and expand

    device = self.get_device()
    assert device is not None
    # Iterate over the values' shape; output_indexer maps each position to the
    # destination index inside self.
    scatter = ir.Scatter(
        device=device,
        dtype=self.get_dtype(),
        inner_fn=values.make_loader(),
        ranges=expected_vals_size,  # iter_ranges,
        output_indexer=inner_fn,
        scatter_mode="atomic_add" if accumulate else None,
    )
    buffer = ir.ComputedBuffer(
        name=None,
        layout=ir.MutationLayoutSHOULDREMOVE(self),
        data=scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)

    if x_ndim == 0:
        self = view(self, [])
    return self
  3190. fallback__unsafe_masked_index = fallback_handler(
  3191. aten._unsafe_masked_index.default, add_to_fallback_set=False
  3192. )
  3193. fallback__unsafe_masked_index_put_accumulate = fallback_handler(
  3194. aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False
  3195. )
  3196. @register_lowering(aten._unsafe_masked_index, type_promotion_kind=None)
  3197. def _unsafe_masked_index(self, mask, indices, fill):
  3198. ranges, _, _unsafe_index_fn = index_impl_helper(
  3199. self, indices, check=False, wrap_neg=False
  3200. )
  3201. mask_loader = mask.make_loader()
  3202. self_loader = self.make_loader()
  3203. def inner_fn(idx):
  3204. if mask.dtype != torch.bool:
  3205. mask_val = ops.to_dtype(mask_loader(idx), torch.bool)
  3206. else:
  3207. mask_val = mask_loader(idx)
  3208. return ops.masked(mask_val, lambda: self_loader(_unsafe_index_fn(idx)), fill)
  3209. return Pointwise.create(
  3210. device=self.get_device(),
  3211. dtype=self.get_dtype(),
  3212. inner_fn=inner_fn,
  3213. ranges=ranges,
  3214. )
  3215. @register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None)
  3216. def _unsafe_masked_index_put_accumulate(x, mask, indices, values):
  3217. masked_value = where(mask, values, 0)
  3218. shape = x.get_size()
  3219. clamped_indices = [
  3220. clamp(indices[i], -shape[i], shape[i] - 1) if indices[i] else None
  3221. for i in range(len(indices))
  3222. ]
  3223. # TODO: use a masked store for this. currently only triton
  3224. # supports masked stores and cpp backend does not.
  3225. return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True)
  3226. @make_pointwise
  3227. def clamp(a, min, max):
  3228. return ops.maximum(min, ops.minimum(max, a))
  3229. @register_lowering(aten.as_strided_scatter, type_promotion_kind=None)
  3230. def as_strided_scatter(self, src, size, stride, storage_offset=None):
  3231. output = clone(self)
  3232. output_view = as_strided(output, size, stride, storage_offset)
  3233. copy_(output_view, src)
  3234. return output
  3235. @register_lowering(aten.scatter, type_promotion_kind=None)
  3236. def scatter(x, dim: int, index, src, **kwargs):
  3237. return scatter_(clone(x), dim, index, src, **kwargs)
  3238. def scatter_fallback(
  3239. op_overload: torch._ops.OpOverload,
  3240. self,
  3241. dim: int,
  3242. index,
  3243. src,
  3244. *,
  3245. reduce: Optional[str] = None,
  3246. include_self: bool = True,
  3247. ):
  3248. src_is_tensor = isinstance(src, TensorBox)
  3249. if use_scatter_fallback(
  3250. op_overload,
  3251. reduce,
  3252. self.get_dtype(),
  3253. cast(torch.dtype, src.get_dtype() if src_is_tensor else type(src)),
  3254. src.get_device().type if src_is_tensor else "not impl",
  3255. src_is_tensor,
  3256. ):
  3257. ir.ScatterFallback(
  3258. op_overload,
  3259. self,
  3260. dim,
  3261. index,
  3262. src,
  3263. reduce=reduce,
  3264. include_self=include_self,
  3265. )
  3266. return self
  3267. return None
  3268. @register_lowering(aten.scatter_, type_promotion_kind=None)
  3269. def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None):
  3270. assert reduce in (None, "add", "multiply")
  3271. if reduce is None:
  3272. op_overload = getattr(aten.scatter_, V.graph.current_node.target._overloadname) # type: ignore[union-attr]
  3273. fallback_result = scatter_fallback(
  3274. op_overload, self, dim, index, src, reduce=reduce
  3275. )
  3276. if fallback_result is not None:
  3277. return fallback_result
  3278. if reduce == "add":
  3279. reduce = "sum"
  3280. elif reduce == "multiply":
  3281. reduce = "prod"
  3282. return scatter_reduce_(self, dim, index, src, reduce)
  3283. @register_lowering(aten.scatter_add, type_promotion_kind=None)
  3284. def scatter_add(x, dim: int, index, src):
  3285. return scatter_add_(clone(x), dim, index, src)
  3286. @register_lowering(aten.scatter_add_, type_promotion_kind=None)
  3287. def scatter_add_(x, dim: int, index, src):
  3288. return scatter_reduce_(x, dim, index, src, "sum")
  3289. @register_lowering(aten.scatter_reduce, type_promotion_kind=None)
  3290. def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
  3291. return scatter_reduce_(clone(x), dim, index, src, reduction_type, **kwargs)
  3292. @register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
  3293. def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
  3294. assert reduce in (None, "sum", "prod", "mean", "amax", "amin")
  3295. assert (
  3296. len(aten.scatter_reduce_.overloads()) == 1
  3297. and "two" in aten.scatter_reduce_.overloads()
  3298. ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_"
  3299. if isinstance(src, Number):
  3300. src = full_like(self, src)
  3301. fallback_result = scatter_fallback(
  3302. aten.scatter_reduce_.two,
  3303. self,
  3304. dim,
  3305. index,
  3306. src,
  3307. reduce=reduce,
  3308. include_self=include_self,
  3309. )
  3310. if fallback_result:
  3311. return fallback_result
  3312. assert isinstance(self, TensorBox)
  3313. assert "int" in str(index.get_dtype())
  3314. ndim = len(self.get_size())
  3315. if ndim == 0:
  3316. self = view(self, [1])
  3317. if isinstance(src, TensorBox) and len(src.get_size()) == 0:
  3318. src = view(src, [1])
  3319. if isinstance(index, TensorBox) and len(index.get_size()) == 0:
  3320. index = view(index, [1])
  3321. if index.get_numel() == 0:
  3322. return self
  3323. dim = _validate_dim(self, dim)
  3324. self.realize()
  3325. index_loader = index.make_loader()
  3326. src_loader = src.make_loader() if isinstance(src, TensorBox) else None
  3327. def output_indexer(idx):
  3328. # self is captured from the end of the function, so it may have 0 dim
  3329. shape = self.get_size()
  3330. ndim = len(shape)
  3331. indirect_idx = list(idx)
  3332. indirect_idx[dim] = ops.indirect_indexing(
  3333. index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False
  3334. )
  3335. return indirect_idx
  3336. def fn(idx):
  3337. if src_loader:
  3338. return src_loader(idx)
  3339. else:
  3340. # src is a scalar
  3341. return ops.constant(src, self.get_dtype())
  3342. def backend_reduce_str(reduce):
  3343. if reduce == "sum":
  3344. return "atomic_add"
  3345. else:
  3346. # TODO: Need to support more reduction type
  3347. assert reduce is None
  3348. return None
  3349. device = self.get_device()
  3350. assert device is not None
  3351. if not include_self:
  3352. # zero out the corresponding elements first
  3353. zero_out = ir.Scatter(
  3354. device=device,
  3355. dtype=self.get_dtype(),
  3356. inner_fn=lambda index: ops.constant(0, self.get_dtype()),
  3357. ranges=index.get_size(),
  3358. output_indexer=output_indexer,
  3359. scatter_mode=None,
  3360. )
  3361. buffer = ir.ComputedBuffer(
  3362. name=None,
  3363. layout=ir.MutationLayoutSHOULDREMOVE(self),
  3364. data=zero_out,
  3365. )
  3366. buffer.name = V.graph.register_buffer(buffer)
  3367. V.graph.register_operation(buffer)
  3368. # self[index[i][j][k]][j][k] += src[i][j][k] # if dim == 0
  3369. # self[i][index[i][j][k]][k] += src[i][j][k] # if dim == 1
  3370. # self[i][j][index[i][j][k]] += src[i][j][k] # if dim == 2
  3371. scatter = ir.Scatter(
  3372. device=device,
  3373. dtype=self.get_dtype(),
  3374. inner_fn=fn,
  3375. ranges=index.get_size(),
  3376. output_indexer=output_indexer,
  3377. scatter_mode=backend_reduce_str(reduce),
  3378. )
  3379. buffer = ir.ComputedBuffer(
  3380. name=None,
  3381. layout=ir.MutationLayoutSHOULDREMOVE(self),
  3382. data=scatter,
  3383. )
  3384. buffer.name = V.graph.register_buffer(buffer)
  3385. V.graph.register_operation(buffer)
  3386. if ndim == 0:
  3387. self = view(self, [])
  3388. return self
def upsample_nearestnd(
    x,
    output_size,
    scales_x: tuple[Optional[float], ...],
    n: int = 2,
    exact: bool = False,
):
    """Shared lowering for nearest-neighbour upsampling over the last ``n`` dims.

    Args:
        x: input TensorBox; its trailing ``n`` dims are upsampled and all
            leading dims are kept as batch dims.
        output_size: target sizes for the ``n`` trailing dims.
        scales_x: one optional scale per upsampled dim. When a scale is given,
            ``1.0 / scale`` replaces the default ``input_size / output_size``.
        n: number of upsampled dims (the callers below use 1, 2 and 3).
        exact: use the "nearest-exact" (+0.5 offset) mapping instead of
            plain "nearest".

    Returns:
        A Pointwise TensorBox with ranges ``[*batch, *output_size]``.
    """
    x.realize_hint()  # elements are reused
    x_loader = x.make_loader()
    i_sizes = x.get_size()[-n:]
    batch = x.get_size()[:-n]
    # Input sizes of the upsampled dims are specialized to concrete ints.
    i_sizes = [V.graph.sizevars.guard_int(i) for i in i_sizes]

    assert len(scales_x) == n
    o_sizes = output_size

    # Default inverse scale is input/output; an explicit scale overrides it.
    inv_scales = [i / o for i, o in zip(i_sizes, o_sizes)]
    for i, scale in enumerate(scales_x):
        if scale is not None:
            inv_scales[i] = 1.0 / scale

    def scale_fn(x, scale, size):
        # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5)
        #                            = floor(scale * (output_index + 0.5))
        # Nearest: input_index = floor(scale * output_index)
        x = ops.index_expr(x, torch.float32)
        if exact:
            x = ops.add(x, ops.constant(0.5, torch.float32))
        x = ops.mul(x, ops.constant(scale, torch.float32))
        # Truncating cast; matches floor as long as the product is
        # non-negative (non-negative indices, positive scales).
        x = ops.to_dtype(x, torch.int32)
        return ops.indirect_indexing(x, size, check=False)

    def fn(idx):
        x = idx[-n:]
        b = idx[:-n]
        return x_loader(
            [*b, *[scale_fn(i, s, size) for i, s, size in zip(x, inv_scales, i_sizes)]]
        )

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=fn,
        ranges=[*batch, *o_sizes],
    )
  3429. @register_lowering(aten.upsample_nearest1d.default)
  3430. def upsample_nearest1d(x, output_size, scales: Optional[float] = None):
  3431. return upsample_nearestnd(x, output_size, (scales,), n=1)
  3432. @register_lowering(aten._upsample_nearest_exact1d.default)
  3433. def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None):
  3434. return upsample_nearestnd(x, output_size, (scales,), n=1, exact=True)
  3435. @register_lowering(aten.upsample_nearest2d.default)
  3436. def upsample_nearest2d(
  3437. x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
  3438. ):
  3439. return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2)
  3440. @register_lowering(aten._upsample_nearest_exact2d.default)
  3441. def _upsample_nearest_exact2d(
  3442. x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
  3443. ):
  3444. return upsample_nearestnd(x, output_size, (scales_h, scales_w), n=2, exact=True)
  3445. @register_lowering(aten.upsample_nearest3d.default)
  3446. def upsample_nearest3d(
  3447. x,
  3448. output_size,
  3449. scales_d: Optional[float] = None,
  3450. scales_h: Optional[float] = None,
  3451. scales_w: Optional[float] = None,
  3452. ):
  3453. return upsample_nearestnd(x, output_size, (scales_d, scales_h, scales_w), n=3)
  3454. @register_lowering(aten._upsample_nearest_exact3d.default)
  3455. def _upsample_nearest_exact3d(
  3456. x,
  3457. output_size,
  3458. scales_d: Optional[float] = None,
  3459. scales_h: Optional[float] = None,
  3460. scales_w: Optional[float] = None,
  3461. ):
  3462. return upsample_nearestnd(
  3463. x, output_size, (scales_d, scales_h, scales_w), n=3, exact=True
  3464. )
  3465. def _create_constants(*args, dtype):
  3466. return tuple(ops.constant(a, dtype) for a in args)
  3467. @register_lowering(prims.rev.default)
  3468. def rev(x, dims):
  3469. # note - dims pre-canonicalized
  3470. x_loader = x.make_loader()
  3471. sizes = x.get_size()
  3472. def loader(idx):
  3473. idx = list(idx)
  3474. assert len(idx) == len(sizes)
  3475. for dim in dims:
  3476. idx[dim] = (sizes[dim] - 1) - idx[dim]
  3477. return x_loader(idx)
  3478. return Pointwise.create(
  3479. device=x.get_device(),
  3480. dtype=x.get_dtype(),
  3481. inner_fn=loader,
  3482. ranges=sizes,
  3483. )
def inplace_constant_pad_nd(
    x: TensorBox, padding: Sequence[int], fill_value: float
) -> Optional[TensorBox]:
    """
    This optimization changes the semantics of padding from 'clone'
    style to 'view' style.

    Thanks to functionalization, this change can still maintain numerical
    correctness.

    Only one pattern is handled: a 2-D, row-contiguous (stride[1] == 1)
    buffer that is padded only on the high side of its last dim
    (``padding == [0, npad, 0, 0]``), whose row stride already leaves room for
    ``npad`` extra columns. The padded columns are filled in place and an
    ``as_strided`` view of the larger size is returned.

    Returns:
        The padded view, or ``None`` when the optimization does not apply and
        the caller must pad out-of-place.
    """

    def _padding_can_be_fused():
        """
        Conservatively check if padding can be fused with downstream op.
        1. if the downstream op is a sum, then there is little benefit to
           do inplace padding
        2. if the downstream op is a matmul, doing inplace padding can
           save membw.

        Returns False only when the current node has exactly one user and
        that user is ``aten.mm`` or ``aten.addmm``; returns True otherwise,
        which disables in-place padding.
        """
        current_node = V.graph.current_node
        if current_node is None:
            return True  # be conservative
        users = tuple(current_node.users)
        if len(users) == 1 and users[0].target in (
            aten.mm.default,
            aten.addmm.default,
        ):
            return False

        return True  # be conservative

    if _padding_can_be_fused():
        return None

    # Only handle 2D case for now
    if len(padding) != 4 or len(x.get_size()) != 2:
        return None

    # No harm to realize since we already know that
    # the op can not be fused into the single user.
    # It need to be realized later anyways.
    x.realize()

    # If x is a view (e.g. a SliceView), realizing it only realizes the
    # underlying storage; x itself is still a view. Require x to wrap a named
    # ComputedBuffer (or a graph InputBuffer when the config allows it).
    if (
        not isinstance(x, ir.TensorBox)
        or not isinstance(x.data, ir.StorageBox)
        or not (
            isinstance(x.data.data, ir.ComputedBuffer)
            or (
                config.can_inplace_pad_graph_input
                and isinstance(x.data.data, ir.InputBuffer)
            )
        )
        or not x.data.data.name
    ):
        return None
    x.freeze_layout()

    _, layout = ir.as_storage_and_layout(x)
    strides = layout.stride
    # Columns must be contiguous.
    if strides[1] != 1:
        return None

    # padding pairs are (low, high) per dim starting from the last dim, so
    # padding[1] is the high-side pad of the last dim; all others must be 0.
    if padding[0] != 0 or padding[2] != 0 or padding[3] != 0:
        return None

    npad = padding[1]
    if npad == 0:
        return None

    stride0 = strides[0]
    rowsize = layout.size[1]

    # The existing row stride must already leave room for npad extra columns.
    if stride0 < rowsize + npad:
        return None

    bufname = x.data.data.name
    padded_size = [layout.size[0], layout.size[1] + npad]
    V.graph.buffer_to_padded_size[bufname] = padded_size
    resized_x = as_strided(
        x,
        padded_size,
        layout.stride,
        layout.offset,
    )

    # Fill the newly exposed columns [rowsize, rowsize + npad) in place.
    sliced_x = slice_(resized_x, dim=1, start=rowsize, end=rowsize + npad)
    fill_(sliced_x, fill_value)

    counters["inductor"]["inplace_padding"] += 1
    return resized_x
@register_lowering(aten.constant_pad_nd, type_promotion_kind=None)
def constant_pad_nd(x, padding, fill_value=0):
    """Lower ``aten.constant_pad_nd``: pad ``x`` with ``fill_value``.

    ``padding`` holds (low, high) pairs starting from the LAST dimension, as
    in ``torch.nn.functional.pad``. Each padded dim grows by ``low + high``.
    Output positions that map outside ``x`` read ``fill_value`` (cast to the
    Python type of ``x``'s dtype). When ``config.inplace_padding`` is set,
    ``inplace_constant_pad_nd`` is tried first.
    """
    assert (len(padding) % 2) == 0
    if all(p == 0 for p in padding):
        return clone(x)

    if config.inplace_padding:
        out = inplace_constant_pad_nd(x, padding, fill_value)
        if out:
            return out
        # fall through if can not inplace the padding

    sizes = x.get_size()

    # (low, high) per padded dim, reordered to go from the first padded dim
    # to the last.
    bounds = list(reversed(list(zip(padding[::2], padding[1::2]))))
    # Number of leading dims that are not padded.
    n = len(sizes) - len(bounds)

    # if padding is a complicated expression, hoist it
    bounds_precomp: list[tuple[sympy.Symbol, Any]] = []
    for l, h in bounds:
        bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h))  # type: ignore[arg-type]

    output_size = list(sizes[:n])
    mask_sizes = []
    for (low, high), size in zip(bounds, sizes[n:]):
        mask_sizes.append(size)
        output_size.append(sympy.expand(size + low + high))
    assert len(output_size) == len(sizes)
    fill_value = dtype_to_type(x.get_dtype())(fill_value)

    def mask(index):
        # ``index`` is already shifted into x's coordinates; only emit bound
        # checks for sides that actually have padding.
        mask = []
        for idx, (low, high), length in zip(index[n:], bounds, mask_sizes):
            if low != 0:
                mask.append(range_mask_low(idx, 0))
            if high != 0:
                mask.append(range_mask_high(idx, length))
        mask = functools.reduce(ops.and_, mask)
        return ops.masked(mask, lambda: x_loader(index), fill_value)

    def offset_fn(index):
        # Map an output position to x's coordinates by subtracting low pads.
        new_index = list(index[:n])
        for idx, (low, _high) in zip(index[n:], bounds_precomp):
            new_index.append(idx - low)
        assert len(new_index) == len(index)
        return mask(new_index)

    x_loader = x.make_loader()
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=offset_fn,
        ranges=output_size,
    )
  3608. def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]):
  3609. return ops.ge(
  3610. ops.index_expr(i, torch.int64),
  3611. ops.index_expr(sympy.Integer(low), torch.int64),
  3612. )
  3613. def range_mask_high(i: sympy.Expr, high: sympy.Expr):
  3614. return ops.lt(
  3615. ops.index_expr(i, torch.int64),
  3616. ops.index_expr(high, torch.int64),
  3617. )
  3618. def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr):
  3619. return ops.and_(
  3620. range_mask_low(i, low),
  3621. range_mask_high(i, high),
  3622. )
def constant_boundary_condition(
    x, fill_value, padding=None, pad_fill_value=1.0, dim=None
):
    """Build a loader for ``x`` that substitutes constants outside its bounds.

    The returned ``load(index)`` checks the trailing ``dim`` coordinates of
    ``index`` (``dim`` must be supplied; it is used as a count and slice bound).

    * Without ``padding``: coordinates in ``[0, size)`` read ``x``; anything
      else yields ``fill_value``.
    * With ``padding``: the accepted region widens to
      ``[-padding[i], size[i] + padding[i])``. Inside that region, positions
      outside ``x`` yield ``pad_fill_value`` (via a nested, un-padded boundary
      loader). Positions beyond the widened region yield ``fill_value``.
    """
    h = x.get_size()[-dim:]
    x_loader = x.make_loader()
    padding_h = padding or [0] * dim

    def load(index):
        prefix = index[:-dim]
        ih = index[-dim:]

        # In-bounds test for every checked dim: -pad <= ih[i] < size + pad.
        mask = functools.reduce(
            ops.and_,
            [range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)],
        )
        return (
            ops.masked(
                mask,
                lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)(
                    [*prefix, *ih]
                ),
                fill_value,
            )
            if padding
            else ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value)
        )

    return load
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None):
    """Compute the pooled output length of dimension ``i`` for input length ``x``.

    Floor mode uses ``(x + 2*pad - dil*(k-1) + (stride-1)) // stride``, which
    equals ``(x + 2*pad - dil*(k-1) - 1) // stride + 1``. Ceil mode adds a
    further ``stride - 1`` before dividing. It then drops the last window if
    that window would start at or beyond ``x + padding``, since windows must
    start inside the input or the left padding. Both decisions are taken on
    size hints and recorded as guards.

    Returns:
        ``(output_size, ceil_mode)``. ``ceil_mode`` is reported as False when
        it does not change the size for this dim.
    """
    if dilation is None:
        dilation = [1] * len(padding)

    x_out = FloorDiv(
        x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1) + (stride[i] - 1),
        stride[i],
    )

    if ceil_mode:
        x_alt = FloorDiv(
            x
            + 2 * padding[i]
            - dilation[i] * (kernel_size[i] - 1)
            + 2 * (stride[i] - 1),
            stride[i],
        )
        if V.graph.sizevars.size_hint((x_alt - 1) * stride[i] - x - padding[i]) >= 0:
            # Sliding windows must start within the input or left padding
            x_alt -= 1  # type: ignore[assignment]
            V.graph.sizevars.check_leq(0, x_alt * stride[i] - x - padding[i])  # type: ignore[arg-type]
        if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
            # ceil mode is actually a no-op, lets guard on that
            V.graph.sizevars.check_equals(x_out, x_alt)
            ceil_mode = False
        else:
            x_out = x_alt
    return x_out, ceil_mode
  3674. def should_fallback_max_pool_with_indices(kernel_size, *, n_dim):
  3675. kernel_size = pad_listlike(kernel_size, n_dim)
  3676. window_size = functools.reduce(operator.mul, kernel_size)
  3677. return window_size > 25
  3678. def max_pool_checks(
  3679. x, kernel_size, stride, padding, dilation, n_dim, *, assert_fallback=None
  3680. ):
  3681. if padding == 0:
  3682. padding = [0] * n_dim
  3683. if dilation == 1:
  3684. dilation = [1] * n_dim
  3685. if not stride:
  3686. stride = kernel_size
  3687. kernel_size = pad_listlike(kernel_size, n_dim)
  3688. stride = pad_listlike(stride, n_dim)
  3689. padding = pad_listlike(padding, n_dim)
  3690. dilation = pad_listlike(dilation, n_dim)
  3691. assert isinstance(x, TensorBox)
  3692. assert len(kernel_size) == n_dim
  3693. assert len(stride) == n_dim
  3694. assert len(padding) == n_dim
  3695. assert len(dilation) == n_dim
  3696. assert len(x.get_size()) in (n_dim + 1, n_dim + 2)
  3697. use_fallback = should_fallback_max_pool_with_indices(kernel_size, n_dim=n_dim)
  3698. if assert_fallback is not None:
  3699. assert use_fallback == assert_fallback
  3700. return kernel_size, stride, padding, dilation, use_fallback
def _max_pool_with_offsets(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    *,
    n_dim,
):
    """Max pooling over the trailing ``n_dim`` dims, returning window offsets.

    Returns:
        ``(result, offsets)``. ``result`` holds the window maxima (dtype of
        ``x``). ``offsets`` is an int64 argmax over the window, i.e. a flat
        position over ``kernel_size`` (decoded by ``_pool_offsets_to_indices``).

    Positions outside the input load the dtype's minimum (False for bool,
    -inf for floating point, iinfo.min for integers), so padding does not
    exceed real values. Bounds-checked loads are used only when padding,
    ceil_mode, or dilation > 1 is in effect.
    """
    x.realize_hint()
    batch = x.shape[:-n_dim]
    dhw = x.shape[-n_dim:]

    # Per-dim output sizes; pooling_size may report ceil_mode=False for dims
    # where it makes no difference.
    dhw_out, ceil_mode = zip(
        *[
            pooling_size(
                dhw[d], d, kernel_size, stride, padding, ceil_mode, dilation=dilation
            )
            for d in range(n_dim)
        ]
    )

    dtype = x.dtype
    min_value = (
        False
        if dtype is torch.bool
        else (float("-inf") if dtype.is_floating_point else torch.iinfo(dtype).min)
    )

    new_size = list(batch) + list(dhw_out)
    if any(padding) or any(ceil_mode) or any(d > 1 for d in dilation):
        x_loader = constant_boundary_condition(x, min_value, dim=n_dim)
    else:
        x_loader = x.make_loader()

    def fn_inner(idx, reduction_idx):
        # Input coordinate = out * stride + window_pos * dilation - padding.
        prefix = idx[:-n_dim]
        bh = idx[-n_dim:]
        ih = [
            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
            for i in range(n_dim)
        ]
        return x_loader([*prefix, *ih])

    result = Reduction.create(
        reduction_type="max",
        input_node=x,
        device=x.get_device(),
        dst_dtype=dtype,
        src_dtype=dtype,
        inner_fn=fn_inner,
        ranges=new_size,
        reduction_ranges=kernel_size,
    )
    offsets = Reduction.create(
        reduction_type="argmax",
        input_node=x,
        device=x.get_device(),
        dst_dtype=torch.int64,
        src_dtype=dtype,
        inner_fn=fn_inner,
        ranges=new_size,
        reduction_ranges=kernel_size,
    )
    if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined, union-attr]
        # Only realize if reduction isn't unrolled
        result.realize()
    if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined, union-attr]
        # Only realize if reduction isn't unrolled
        offsets.realize()

    return result, offsets
@register_lowering(prims._low_memory_max_pool_with_offsets, type_promotion_kind=None)
def _low_memory_max_pool_with_offsets(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode=False,
):
    """Lowering for ``prims._low_memory_max_pool_with_offsets``.

    Returns the pooled values and the in-window argmax offsets as int8. The
    int8 cast relies on this being the non-fallback path: ``max_pool_checks``
    asserts ``assert_fallback=False``, i.e. the window has at most 25
    elements, so every offset fits.
    """
    n_dim = len(kernel_size)

    # assert we are not on a fallback path, the inductor decomp should have guaranteed this
    kernel_size, stride, padding, dilation, _ = max_pool_checks(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        n_dim,
        assert_fallback=False,
    )

    # NOTE(review): presumably lets reductions over windows of up to 25
    # elements be unrolled — confirm unroll_reductions_threshold semantics.
    with config.patch(unroll_reductions_threshold=25):
        result, offsets = _max_pool_with_offsets(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            n_dim=n_dim,
        )
        return result, to_dtype(offsets, torch.int8)
def _pool_offsets_to_indices(
    offsets: TensorBox,
    kernel_size: Sequence[Union[int, torch.SymInt]],
    input_size: Sequence[Union[int, torch.SymInt]],
    increments_to_index: Callable[
        [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]],
        torch._inductor.virtualized.OpsValue,
    ],
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Convert flat in-window offsets into flat int64 input indices.

    Each offset is decoded into per-dim window coordinates over
    ``kernel_size``. ``increments_to_index(idx, reduction_idx)`` then maps the
    output position ``idx`` plus those coordinates to input coordinates.
    Finally these are flattened over the trailing ``len(kernel_size)`` dims of
    ``input_size``. The result has the same size as ``offsets``.
    """
    n_dim = len(kernel_size)
    offsets_loader = offsets.make_loader()
    # Product of kernel sizes: the bound handed to indirect_indexing below.
    window_size = sympy.sympify(functools.reduce(operator.mul, kernel_size))

    def offsets_to_indices(idx):
        offset = offsets_loader(idx)
        offset_sympy = ops.indirect_indexing(offset, window_size)
        reduction_idx = inductor_prims._flattened_index_to_nd(offset_sympy, kernel_size)
        idhw = increments_to_index(idx, reduction_idx)
        return ops.index_expr(
            inductor_prims._flatten_index(idhw, input_size[-n_dim:]), torch.int64
        )

    indices = Pointwise.create(
        device=offsets.get_device(),
        dtype=torch.int64,
        inner_fn=offsets_to_indices,
        ranges=offsets.get_size(),
    )
    return indices
  3826. @register_lowering(
  3827. prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
  3828. )
  3829. def _low_memory_max_pool_offsets_to_indices(
  3830. offsets, kernel_size, input_size, stride, padding, dilation
  3831. ):
  3832. # TODO: Generalize to other max pooling flavors
  3833. n_dim = len(kernel_size)
  3834. def increments_to_index(idx, reduction_idx):
  3835. bh = idx[-n_dim:]
  3836. return [
  3837. (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
  3838. for i in range(n_dim)
  3839. ]
  3840. return _pool_offsets_to_indices(
  3841. offsets, kernel_size, input_size, increments_to_index
  3842. )
  3843. def _max_pool_with_indices(
  3844. x,
  3845. kernel_size,
  3846. stride,
  3847. padding,
  3848. dilation,
  3849. ceil_mode,
  3850. n_dim,
  3851. ):
  3852. kernel_size, stride, padding, dilation, _ = max_pool_checks(
  3853. x, kernel_size, stride, padding, dilation, n_dim=n_dim
  3854. )
  3855. out, offsets = _max_pool_with_offsets(
  3856. x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim
  3857. )
  3858. indices = _low_memory_max_pool_offsets_to_indices(
  3859. offsets,
  3860. kernel_size,
  3861. x.shape[-n_dim:],
  3862. stride,
  3863. padding,
  3864. dilation,
  3865. )
  3866. return out, indices
  3867. # Fallback when we do not decompose to the low-memory path.
  3868. @register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
  3869. def max_pool2d_with_indices(
  3870. x,
  3871. kernel_size,
  3872. stride=None,
  3873. padding=0,
  3874. dilation=1,
  3875. ceil_mode=False,
  3876. ):
  3877. return _max_pool_with_indices(
  3878. x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=2
  3879. )
  3880. # Fallback when we do not decompose to the low-memory path.
  3881. @register_lowering(aten.max_pool3d_with_indices, type_promotion_kind=None)
  3882. def max_pool3d_with_indices(
  3883. x,
  3884. kernel_size,
  3885. stride=None,
  3886. padding=0,
  3887. dilation=1,
  3888. ceil_mode=False,
  3889. ):
  3890. return _max_pool_with_indices(
  3891. x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=3
  3892. )
# ATen fallback used by the max_pool2d_with_indices_backward lowering when it
# cannot generate its own kernel (dilation != 1 or a too-large window).
fallback_max_pool2d_with_indices_backward = fallback_handler(
    aten.max_pool2d_with_indices_backward.default,
    add_to_fallback_set=False,
)
@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None)
def max_pool2d_with_indices_backward(
    grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
):
    """Lowering for the backward pass of 2-D max pooling with saved indices.

    For every input position ``(h, w)`` the gradient is the sum of
    ``grad_output`` over the pooled positions whose saved ``indices`` entry
    equals the flat input offset ``h * width + w``. Only pooled positions
    whose window can contain ``(h, w)`` are visited.

    Uses ``fallback_max_pool2d_with_indices_backward`` when any dilation is
    not 1 or when the candidate window exceeds 25 positions. The output is
    forced to channels-last when ``x`` or ``grad_output`` has stride 1 in
    dim 1.
    """
    if padding == 0:
        padding = [0, 0]
    if dilation == 1:
        dilation = [1, 1]
    if not stride:
        stride = kernel_size

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(padding) == 2
    assert len(dilation) == 2
    assert len(x.get_size()) in (3, 4)

    # we will read this many times, so make sure it is computed
    grad_output.realize_hint()
    gO_stride = grad_output.maybe_get_stride()
    x_stride: Optional[Sequence[Any]]
    if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):  # type: ignore[attr-defined]
        # x is an unrealized Pointwise without strides: wrap it in a
        # ComputedBuffer (not registered with the graph here) and let it pick
        # a layout, only to read the stride it would get.
        data = x.data.data  # type: ignore[attr-defined]
        device = data.get_device()
        assert device is not None
        x_buffer = ir.ComputedBuffer(
            name=None,
            layout=ir.FlexibleLayout(
                device=device,
                dtype=data.get_dtype(),
                size=data.get_size(),
            ),
            data=data,
        )
        x_buffer.decide_layout()
        x_stride = x_buffer.get_stride()
    else:
        x_stride = x.maybe_get_stride()

    is_channels_last = (x_stride is not None and x_stride[1] == 1) or (
        gO_stride is not None and gO_stride[1] == 1
    )
    if any(d != 1 for d in dilation):
        # dilation NYI
        return fallback_max_pool2d_with_indices_backward(
            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
        )

    *_batch, _height, width = x.get_size()
    *_, pooled_height, pooled_width = grad_output.get_size()

    indices_loader = indices.make_loader()
    grad_loader = grad_output.make_loader()
    new_size = list(x.get_size())

    # Static per-dim bound on how many pooled positions are visited for one
    # input position (maximum of a per-position count over positions in
    # [0, 2 * kernel), at least 1). It fixes the unrolled loop trip counts.
    h_window_size = max(
        max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
        for h in range(kernel_size[0] * 2)
    )
    w_window_size = max(
        max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
        for w in range(kernel_size[1] * 2)
    )

    window_size = h_window_size * w_window_size

    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_max_pool2d_with_indices_backward(
            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
        )

    indices_size = indices.get_size()

    def fn(idx):
        *prefix, h, w = idx
        # Flat offset of this input position, compared with the saved indices.
        index_test = ops.index_expr(h * width + w, torch.int32)
        # Shift into padded coordinates.
        h = h + padding[0]
        w = w + padding[1]
        # [start, end) of pooled rows/cols whose window can contain (h, w),
        # clamped to the pooled output extent.
        phstart = ops.index_expr(
            FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
        )
        pwstart = ops.index_expr(
            FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
        )
        phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
        pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)

        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))

        gradient = None
        for ph_ in range(h_window_size):
            for pw_ in range(w_window_size):
                ph = ops.add(phstart, ops.constant(ph_, torch.int32))
                pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
                # Clamp to end - 1 so the loads stay in range; candidates past
                # the end are excluded by the mask below.
                grad_index = [
                    *prefix,
                    ops.indirect_indexing(
                        ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))),
                        indices_size[-2],
                        check=False,
                    ),
                    ops.indirect_indexing(
                        ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))),
                        indices_size[-1],
                        check=False,
                    ),
                ]

                index_actual = indices_loader(grad_index)
                grad_part = grad_loader(grad_index)
                # This pooled position contributes only if its argmax is (h, w).
                check = ops.eq(index_actual, index_test)

                if gradient is None:
                    # don't need mask for 0, 0
                    gradient = ops.where(
                        check, grad_part, ops.constant(0.0, torch.float32)
                    )
                else:
                    mask = ops.and_(
                        ops.and_(
                            ops.lt(ph, phend),
                            ops.lt(pw, pwend),
                        ),
                        check,
                    )
                    gradient = ops.where(mask, ops.add(gradient, grad_part), gradient)
        assert gradient is not None
        return gradient

    out = Pointwise.create(
        device=grad_output.get_device(),
        dtype=grad_output.get_dtype(),
        inner_fn=fn,
        ranges=new_size,
    )
    if is_channels_last:
        return ir.ExternKernel.require_channels_last(out)
    else:
        return out
def pad_adaptive_loader(x, pad_val=0.0):
    """Return a masked loader for adaptive-pooling windows over the last two dims of ``x``.

    ``load(prefix, increments, start_indices, end_indices)`` reads
    ``x[*prefix, h_start + ih, w_start + iw]`` when ``h_start + ih < h_end``
    and ``w_start + iw < w_end``; otherwise it yields ``pad_val``.
    """
    x_loader = x.make_loader()

    def load(prefix, increments, start_indices, end_indices):
        ih, iw = increments
        h_start_index, w_start_index = start_indices
        h_end_index, w_end_index = end_indices

        # Both coordinates must lie before the end of this output's window.
        mask = ops.and_(
            ops.lt(
                ops.index_expr(h_start_index + ih, torch.int64),
                ops.index_expr(h_end_index, torch.int64),
            ),
            ops.lt(
                ops.index_expr(w_start_index + iw, torch.int64),
                ops.index_expr(w_end_index, torch.int64),
            ),
        )

        return ops.masked(
            mask,
            lambda: x_loader([*prefix, h_start_index + ih, w_start_index + iw]),
            pad_val,
        )

    return load
  4048. def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
  4049. h_start_index = functools.partial(start_index, out_dim=h_out, inp_dim=h_in)
  4050. h_end_index = functools.partial(end_index, out_dim=h_out, inp_dim=h_in)
  4051. w_start_index = functools.partial(start_index, out_dim=w_out, inp_dim=w_in)
  4052. w_end_index = functools.partial(end_index, out_dim=w_out, inp_dim=w_in)
  4053. return h_start_index, h_end_index, w_start_index, w_end_index
  4054. def _adaptive_pooling_fn(
  4055. start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
  4056. ):
  4057. h_in, w_in = in_sizes
  4058. h_out, w_out = out_sizes
  4059. (
  4060. h_start_index_fn,
  4061. h_end_index_fn,
  4062. w_start_index_fn,
  4063. w_end_index_fn,
  4064. ) = compute_indices_adaptive_pooling(
  4065. start_index, end_index, h_in, w_in, h_out, w_out
  4066. )
  4067. def fn(idx, loader):
  4068. *prefix, bh, bw = idx
  4069. h_start_index = h_start_index_fn(bh)
  4070. h_end_index = h_end_index_fn(bh)
  4071. w_start_index = w_start_index_fn(bw)
  4072. w_end_index = w_end_index_fn(bw)
  4073. result = None
  4074. for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
  4075. val = loader(
  4076. prefix,
  4077. [ih, iw],
  4078. [h_start_index, w_start_index],
  4079. [h_end_index, w_end_index],
  4080. )
  4081. if result is None:
  4082. result = val
  4083. else:
  4084. result = pooling_fn(val, result)
  4085. return result
  4086. return fn
  4087. def _adaptive_pooling_fn_with_idx(
  4088. start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
  4089. ):
  4090. h_in, w_in = in_sizes
  4091. h_out, w_out = out_sizes
  4092. (
  4093. h_start_index_fn,
  4094. h_end_index_fn,
  4095. w_start_index_fn,
  4096. w_end_index_fn,
  4097. ) = compute_indices_adaptive_pooling(
  4098. start_index, end_index, h_in, w_in, h_out, w_out
  4099. )
  4100. def fn(idx, loader):
  4101. *prefix, bh, bw = idx
  4102. h_start_index = h_start_index_fn(bh)
  4103. h_end_index = h_end_index_fn(bh)
  4104. w_start_index = w_start_index_fn(bw)
  4105. w_end_index = w_end_index_fn(bw)
  4106. maxval = None
  4107. maxindex = None
  4108. for ih, iw in itertools.product(range(kernel_maxes[0]), range(kernel_maxes[1])):
  4109. val = loader(
  4110. prefix,
  4111. [ih, iw],
  4112. [h_start_index, w_start_index],
  4113. [h_end_index, w_end_index],
  4114. )
  4115. index = ops.index_expr(
  4116. (h_start_index + ih) * w_in + w_start_index + iw, torch.int64
  4117. )
  4118. if maxindex is None:
  4119. maxindex = index
  4120. else:
  4121. maxindex = ops.where(ops.gt(val, maxval), index, maxindex)
  4122. if maxval is None:
  4123. maxval = val
  4124. else:
  4125. maxval = pooling_fn(val, maxval)
  4126. return maxindex
  4127. return fn
  4128. fallback_adaptive_avg_pool2d = fallback_handler(
  4129. aten._adaptive_avg_pool2d.default, add_to_fallback_set=False
  4130. )
  4131. @register_lowering(aten._adaptive_avg_pool2d)
  4132. def _adaptive_avg_pool2d(x, output_size):
  4133. if x.get_dtype() == torch.int64:
  4134. # not supported in eager
  4135. raise RuntimeError("'adaptive_avg_pool2d' not implemented for 'Long'")
  4136. assert isinstance(x, TensorBox)
  4137. assert len(output_size) == 2
  4138. x.realize_hint()
  4139. *batch, h_in, w_in = x.get_size()
  4140. h_in = V.graph.sizevars.guard_int(h_in)
  4141. w_in = V.graph.sizevars.guard_int(w_in)
  4142. h_out, w_out = output_size
  4143. # no-op if the same input and output
  4144. if h_in == h_out and w_in == w_out:
  4145. return clone(x)
  4146. if h_out == 0 or w_out == 0:
  4147. o_size = [*batch, h_out, w_out]
  4148. return empty(o_size, dtype=x.get_dtype(), device=x.get_device())
  4149. if h_in % h_out == 0 and w_in % w_out == 0:
  4150. kernel_size = [h_in // h_out, w_in // w_out]
  4151. return avg_pool2d(x, kernel_size)
  4152. h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
  4153. w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
  4154. new_size = list(batch) + [h_out, w_out]
  4155. dtype = x.get_dtype()
  4156. window_size = h_kernel_max * w_kernel_max
  4157. if window_size > 25:
  4158. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  4159. return fallback_adaptive_avg_pool2d(x, output_size)
  4160. def start_index(index, out_dim, inp_dim):
  4161. return FloorDiv((index * inp_dim), out_dim)
  4162. def end_index(index, out_dim, inp_dim):
  4163. return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
  4164. fn_sum = _adaptive_pooling_fn(
  4165. start_index=start_index,
  4166. end_index=end_index,
  4167. kernel_maxes=[h_kernel_max, w_kernel_max],
  4168. in_sizes=[h_in, w_in],
  4169. out_sizes=[h_out, w_out],
  4170. pooling_fn=ops.add,
  4171. )
  4172. ones_loader = pad_adaptive_loader(ones_like(x))
  4173. def fn(idx):
  4174. return ops.truediv(
  4175. fn_sum(idx, pad_adaptive_loader(x)), fn_sum(idx, ones_loader)
  4176. )
  4177. rv = Pointwise.create(
  4178. device=x.get_device(),
  4179. dtype=dtype,
  4180. inner_fn=fn,
  4181. ranges=new_size,
  4182. )
  4183. # TODO: should we force these to be realized?
  4184. return rv
  4185. fallback_adaptive_max_pool2d = fallback_handler(
  4186. aten.adaptive_max_pool2d.default, add_to_fallback_set=False
  4187. )
  4188. @register_lowering(aten.adaptive_max_pool2d)
  4189. def adaptive_max_pool2d(x, output_size):
  4190. if x.get_dtype() == torch.int64:
  4191. # not supported in eager
  4192. raise RuntimeError("adaptive_max_pool2d not implemented for Long")
  4193. assert isinstance(x, TensorBox)
  4194. assert len(output_size) == 2
  4195. x.realize_hint()
  4196. *batch, h_in, w_in = x.get_size()
  4197. h_in = V.graph.sizevars.guard_int(h_in)
  4198. w_in = V.graph.sizevars.guard_int(w_in)
  4199. h_out, w_out = output_size
  4200. if h_out == 0 or w_out == 0:
  4201. o_size = [*batch, h_out, w_out]
  4202. return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty(
  4203. o_size, dtype=torch.int64, device=x.get_device()
  4204. )
  4205. if h_in % h_out == 0 and w_in % w_out == 0:
  4206. # This is handled by a decomposition
  4207. raise ValueError
  4208. h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
  4209. w_kernel_max = ceildiv((w_in + w_out - 1), w_out)
  4210. new_size = list(batch) + [h_out, w_out]
  4211. dtype = x.get_dtype()
  4212. window_size = h_kernel_max * w_kernel_max
  4213. if window_size > 25:
  4214. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  4215. return fallback_adaptive_max_pool2d(x, output_size)
  4216. def start_index(index, out_dim, inp_dim):
  4217. return FloorDiv((index * inp_dim), out_dim)
  4218. def end_index(index, out_dim, inp_dim):
  4219. return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)
  4220. inner_func_max_val = _adaptive_pooling_fn(
  4221. start_index=start_index,
  4222. end_index=end_index,
  4223. kernel_maxes=[h_kernel_max, w_kernel_max],
  4224. in_sizes=[h_in, w_in],
  4225. out_sizes=[h_out, w_out],
  4226. pooling_fn=ops.maximum,
  4227. )
  4228. inner_func_max_idx = _adaptive_pooling_fn_with_idx(
  4229. start_index=start_index,
  4230. end_index=end_index,
  4231. kernel_maxes=[h_kernel_max, w_kernel_max],
  4232. in_sizes=[h_in, w_in],
  4233. out_sizes=[h_out, w_out],
  4234. pooling_fn=ops.maximum,
  4235. )
  4236. def inner_fn_max_val(idx):
  4237. return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf")))
  4238. def inner_fn_max_idx(idx):
  4239. return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf")))
  4240. rv = Pointwise.create(
  4241. device=x.get_device(),
  4242. dtype=dtype,
  4243. inner_fn=inner_fn_max_val,
  4244. ranges=new_size,
  4245. )
  4246. ri = Pointwise.create(
  4247. device=x.get_device(),
  4248. dtype=torch.int64,
  4249. inner_fn=inner_fn_max_idx,
  4250. ranges=new_size,
  4251. )
  4252. return rv, ri
  4253. def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
  4254. out_sz = out_sz[dim]
  4255. in_sz = in_sz[dim]
  4256. kernel_sz = kernel_sz[dim]
  4257. samples_loader = samples.make_loader()
  4258. def load(prefix, i):
  4259. # Handle indexing for samples tensor correctly for different input dimensions
  4260. # samples tensor always has shape (N, C, 2) for fractional_max_pool2d where:
  4261. # - N=1 for 3D inputs (C,H,W), N=batch_size for 4D inputs (N,C,H,W)
  4262. # - C=num_channels
  4263. # - 2 for the two spatial dimensions (height, width)
  4264. samples_shape = samples.get_size()
  4265. if len(samples_shape) == 3: # Expected: (N, C, 2)
  4266. if len(prefix) == 1:
  4267. # 3D input case: prefix=(channel,), samples=(1, C, 2)
  4268. # Access: samples[0, channel, dim]
  4269. sample = samples_loader([0, prefix[0], ndims - 1 - dim])
  4270. elif len(prefix) >= 2:
  4271. # 4D+ input case: prefix=(batch, channel, ...), samples=(batch, C, 2)
  4272. # Access: samples[batch, channel, dim]
  4273. sample = samples_loader([prefix[0], prefix[1], ndims - 1 - dim])
  4274. else:
  4275. # Edge case - shouldn't happen for valid fractional pooling
  4276. sample = samples_loader([0, 0, ndims - 1 - dim])
  4277. else:
  4278. # Fallback for unexpected tensor shapes
  4279. sample = samples_loader([*prefix, ndims - 1 - dim])
  4280. i_expr = ops.index_expr(i, samples.get_dtype())
  4281. diff = ops.index_expr(in_sz - kernel_sz, torch.int64)
  4282. out_sz_expr = ops.index_expr(out_sz - 1, torch.int64)
  4283. alpha = ops.truediv(
  4284. ops.to_dtype(diff, torch.float64), ops.to_dtype(out_sz_expr, torch.float64)
  4285. )
  4286. alpha = ops.where(ops.eq(out_sz_expr, 0), 0, alpha)
  4287. seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
  4288. seq_i = ops.to_dtype(seq_i, torch.int64)
  4289. mask = ops.lt(i_expr, out_sz_expr)
  4290. return ops.indirect_indexing(ops.where(mask, seq_i, diff), sympy.sympify(in_sz))
  4291. return load
  4292. @register_lowering(aten.fractional_max_pool2d)
  4293. def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
  4294. return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=2)
  4295. @register_lowering(aten.fractional_max_pool3d)
  4296. def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
  4297. return _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim=3)
  4298. def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
  4299. x.realize_hint()
  4300. batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]
  4301. with config.patch(unroll_reductions_threshold=25):
  4302. dhw_index_fn = [
  4303. _fractional_pooling_offsets(
  4304. samples=random_samples,
  4305. in_sz=inp_dhw,
  4306. out_sz=output_size,
  4307. kernel_sz=kernel_size,
  4308. ndims=n_dim,
  4309. dim=d,
  4310. )
  4311. for d in range(n_dim)
  4312. ]
  4313. x_loader = x.make_loader()
  4314. def fn_inner(idx, reduction_idx):
  4315. prefix = idx[:-n_dim]
  4316. return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])
  4317. def increments_to_index(idx, reduction_idx):
  4318. prefix = idx[:-n_dim]
  4319. bdhw = idx[-n_dim:]
  4320. return [
  4321. dhw_index_fn[d](prefix, bdhw[d]) + reduction_idx[d]
  4322. for d in range(n_dim)
  4323. ]
  4324. new_size = list(batch) + list(output_size)
  4325. dtype = x.get_dtype()
  4326. result = Reduction.create(
  4327. reduction_type="max",
  4328. input_node=x,
  4329. device=x.get_device(),
  4330. dst_dtype=dtype,
  4331. src_dtype=dtype,
  4332. inner_fn=fn_inner,
  4333. ranges=new_size,
  4334. reduction_ranges=kernel_size,
  4335. )
  4336. offsets = Reduction.create(
  4337. reduction_type="argmax",
  4338. input_node=x,
  4339. device=x.get_device(),
  4340. dst_dtype=torch.int64,
  4341. src_dtype=dtype,
  4342. inner_fn=fn_inner,
  4343. ranges=new_size,
  4344. reduction_ranges=kernel_size,
  4345. )
  4346. assert isinstance(result, TensorBox), result
  4347. if isinstance(result.data.data, Reduction): # type: ignore[attr-defined]
  4348. # Only realize if reduction isn't unrolled
  4349. result.realize()
  4350. assert isinstance(offsets, TensorBox), offsets
  4351. if isinstance(offsets.data.data, Reduction): # type: ignore[attr-defined]
  4352. # Only realize if reduction isn't unrolled
  4353. offsets.realize()
  4354. indices = _pool_offsets_to_indices(
  4355. offsets, kernel_size, x.shape, increments_to_index
  4356. )
  4357. return result, indices
  4358. @register_lowering(aten.upsample_nearest2d_backward.default)
  4359. def upsample_nearest2d_backward(
  4360. x, output_size=None, input_size=None, scales_h=None, scales_w=None
  4361. ):
  4362. x.realize_hint()
  4363. *_batch, inp_h, inp_w = x.get_size()
  4364. inp_h = V.graph.sizevars.guard_int(inp_h)
  4365. inp_w = V.graph.sizevars.guard_int(inp_w)
  4366. *_batch, out_h, out_w = input_size
  4367. if inp_h % out_h == 0 and inp_w % out_w == 0:
  4368. return avg_pool2d(x, [inp_h // out_h, inp_w // out_w], divisor_override=1)
  4369. h_kernel_max = ceildiv(inp_h, out_h)
  4370. w_kernel_max = ceildiv(inp_w, out_w)
  4371. def start_index(index, out_dim, inp_dim):
  4372. return CeilDiv(index * inp_dim, sympy.sympify(out_dim))
  4373. def end_index(index, out_dim, inp_dim):
  4374. return start_index((index + 1), out_dim, inp_dim)
  4375. fn_sum = _adaptive_pooling_fn(
  4376. start_index=start_index,
  4377. end_index=end_index,
  4378. kernel_maxes=[h_kernel_max, w_kernel_max],
  4379. in_sizes=[inp_h, inp_w],
  4380. out_sizes=[out_h, out_w],
  4381. pooling_fn=ops.add,
  4382. )
  4383. def fn(idx):
  4384. return fn_sum(idx, pad_adaptive_loader(x))
  4385. rv = Pointwise.create(
  4386. device=x.get_device(),
  4387. dtype=x.get_dtype(),
  4388. inner_fn=fn,
  4389. ranges=list(input_size),
  4390. )
  4391. return rv
  4392. fallback_avg_pool2d = fallback_handler(
  4393. aten.avg_pool2d.default, add_to_fallback_set=False
  4394. )
  4395. fallback_avg_pool3d = fallback_handler(
  4396. aten.avg_pool3d.default, add_to_fallback_set=False
  4397. )
  4398. @register_lowering(aten.avg_pool2d, type_promotion_kind=None)
  4399. def avg_pool2d(
  4400. x,
  4401. kernel_size,
  4402. stride=(),
  4403. padding=0,
  4404. ceil_mode=False,
  4405. count_include_pad=True,
  4406. divisor_override=None,
  4407. ):
  4408. return _avg_poolnd(
  4409. x,
  4410. kernel_size,
  4411. stride,
  4412. padding,
  4413. ceil_mode,
  4414. count_include_pad,
  4415. divisor_override,
  4416. dim=2,
  4417. )
  4418. @register_lowering(aten.avg_pool3d, type_promotion_kind=None)
  4419. def avg_pool3d(
  4420. x,
  4421. kernel_size,
  4422. stride=(),
  4423. padding=0,
  4424. ceil_mode=False,
  4425. count_include_pad=True,
  4426. divisor_override=None,
  4427. ):
  4428. return _avg_poolnd(
  4429. x,
  4430. kernel_size,
  4431. stride,
  4432. padding,
  4433. ceil_mode,
  4434. count_include_pad,
  4435. divisor_override,
  4436. dim=3,
  4437. )
  4438. def _avg_poolnd(
  4439. x,
  4440. kernel_size,
  4441. stride,
  4442. padding,
  4443. ceil_mode,
  4444. count_include_pad,
  4445. divisor_override,
  4446. dim,
  4447. ):
  4448. if not stride:
  4449. stride = kernel_size
  4450. if not padding:
  4451. padding = [0] * dim
  4452. kernel_size = pad_listlike(kernel_size, dim)
  4453. stride = pad_listlike(stride, dim)
  4454. padding = pad_listlike(padding, dim)
  4455. assert isinstance(x, TensorBox)
  4456. assert len(kernel_size) == dim
  4457. assert len(stride) == dim
  4458. assert len(padding) == dim
  4459. assert len(x.get_size()) in (dim + 1, dim + 2)
  4460. x.realize_hint()
  4461. batch = x.get_size()[:-dim]
  4462. h = x.get_size()[-dim:]
  4463. h_out, ceil_modes = zip(
  4464. *[
  4465. pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode)
  4466. for i in range(dim)
  4467. ]
  4468. )
  4469. if any(padding) or any(ceil_modes):
  4470. x_loader = constant_boundary_condition(x, 0.0, dim=dim)
  4471. had_padding = True
  4472. else:
  4473. x_loader = x.make_loader()
  4474. had_padding = False
  4475. new_size = list(batch) + list(h_out)
  4476. dtype = x.get_dtype()
  4477. window_size = functools.reduce(operator.mul, kernel_size)
  4478. if window_size > 25:
  4479. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  4480. if dim == 2:
  4481. fallback = fallback_avg_pool2d
  4482. elif dim == 3:
  4483. fallback = fallback_avg_pool3d
  4484. else:
  4485. raise ValueError(f"Unknown dim: {dim}")
  4486. return fallback(
  4487. x,
  4488. kernel_size,
  4489. stride,
  4490. padding,
  4491. ceil_mode,
  4492. count_include_pad,
  4493. divisor_override,
  4494. )
  4495. def fn_sum(idx, loader):
  4496. prefix = idx[:-dim]
  4497. b = idx[-dim:]
  4498. total = None
  4499. for ih in itertools.product(*[range(kernel_size[i]) for i in range(dim)]):
  4500. inp = [b[i] * stride[i] + ih[i] - padding[i] for i in range(dim)]
  4501. val = loader([*prefix, *inp])
  4502. if total is None:
  4503. total = val
  4504. else:
  4505. total = ops.add(val, total)
  4506. return total
  4507. if not had_padding or divisor_override:
  4508. divisor = divisor_override if divisor_override else window_size
  4509. if dtype.is_floating_point:
  4510. scale = 1 / divisor
  4511. def fn(idx):
  4512. return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype))
  4513. else:
  4514. def fn(idx):
  4515. # C style integer division as done in native/cpu/AvgPoolKernel.cpp
  4516. return ops.truncdiv(fn_sum(idx, x_loader), ops.constant(divisor, dtype))
  4517. else:
  4518. def fn(idx):
  4519. bh = idx[-dim:]
  4520. divide_factors = []
  4521. for i in range(dim):
  4522. hstart = bh[i] * stride[i] - padding[i]
  4523. hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i])
  4524. if not count_include_pad:
  4525. hstart = sympy.Max(hstart, 0)
  4526. hend = sympy.Min(hend, h[i])
  4527. factor = ops.index_expr(hend - hstart, torch.int32)
  4528. divide_factors.append(factor)
  4529. divide_factor = functools.reduce(ops.mul, divide_factors)
  4530. if dtype.is_floating_point:
  4531. return ops.truediv(fn_sum(idx, x_loader), divide_factor)
  4532. # C style integer division as done in native/cpu/AvgPoolKernel.cpp
  4533. return ops.truncdiv(fn_sum(idx, x_loader), divide_factor)
  4534. rv = Pointwise.create(
  4535. device=x.get_device(),
  4536. dtype=dtype,
  4537. inner_fn=fn,
  4538. ranges=new_size,
  4539. )
  4540. # TODO(jansel): should we force these to be realized?
  4541. return rv
  4542. fallback_avg_pool2d_backward = fallback_handler(
  4543. aten.avg_pool2d_backward.default, add_to_fallback_set=False
  4544. )
  4545. @register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
  4546. def avg_pool2d_backward(
  4547. grad_output,
  4548. x,
  4549. kernel_size,
  4550. stride,
  4551. padding,
  4552. ceil_mode,
  4553. count_include_pad,
  4554. divisor_override=None,
  4555. ):
  4556. assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
  4557. if not stride:
  4558. stride = kernel_size
  4559. if not padding:
  4560. padding = [0, 0]
  4561. assert isinstance(grad_output, TensorBox)
  4562. assert isinstance(x, TensorBox)
  4563. assert len(kernel_size) == 2
  4564. assert len(stride) == 2
  4565. assert len(padding) == 2
  4566. assert len(x.get_size()) in (3, 4)
  4567. grad_output.realize_hint() # we will read this many times, so make sure it is computed
  4568. *_, height, width = x.get_size()
  4569. _h_out, ceil_mode1 = pooling_size(
  4570. height, 0, kernel_size, stride, padding, ceil_mode
  4571. )
  4572. _w_out, ceil_mode2 = pooling_size(width, 1, kernel_size, stride, padding, ceil_mode)
  4573. grad_loader = grad_output.make_loader()
  4574. had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2
  4575. *_, pooled_height, pooled_width = grad_output.get_size()
  4576. new_size = list(x.get_size())
  4577. dtype = x.get_dtype()
  4578. h_window_size = max(
  4579. max(h // stride[0] - max(0, (h - kernel_size[0]) // stride[0]), 1)
  4580. for h in range(kernel_size[0] * 2)
  4581. )
  4582. w_window_size = max(
  4583. max(w // stride[1] - max(0, (w - kernel_size[1]) // stride[1]), 1)
  4584. for w in range(kernel_size[1] * 2)
  4585. )
  4586. window_size = h_window_size * w_window_size
  4587. if window_size > 25:
  4588. # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
  4589. return fallback_avg_pool2d_backward(
  4590. grad_output,
  4591. x,
  4592. kernel_size,
  4593. stride,
  4594. padding,
  4595. ceil_mode,
  4596. count_include_pad,
  4597. divisor_override,
  4598. )
  4599. def compute_pool_size_without_padding(ph, pw):
  4600. """
  4601. This computes the scaling factor that we will divide an element
  4602. by when `count_include_pad=False`
  4603. """
  4604. stride_h = ops.constant(stride[0], torch.int32)
  4605. stride_w = ops.constant(stride[1], torch.int32)
  4606. pad_h = ops.constant(padding[0], torch.int32)
  4607. pad_w = ops.constant(padding[1], torch.int32)
  4608. kernel_h = ops.constant(kernel_size[0], torch.int32)
  4609. kernel_w = ops.constant(kernel_size[1], torch.int32)
  4610. hstart = ops.sub(ops.mul(ph, stride_h), pad_h)
  4611. wstart = ops.sub(ops.mul(pw, stride_w), pad_w)
  4612. hend = ops.minimum(
  4613. ops.add(hstart, kernel_h),
  4614. ops.add(ops.index_expr(height, torch.int32), pad_h),
  4615. )
  4616. wend = ops.minimum(
  4617. ops.add(wstart, kernel_w),
  4618. ops.add(ops.index_expr(width, torch.int32), pad_w),
  4619. )
  4620. hstart = ops.maximum(hstart, ops.constant(0, torch.int32))
  4621. wstart = ops.maximum(wstart, ops.constant(0, torch.int32))
  4622. hend = ops.minimum(hend, ops.index_expr(height, torch.int32))
  4623. wend = ops.minimum(wend, ops.index_expr(width, torch.int32))
  4624. divide_factor = ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart))
  4625. return divide_factor
  4626. def fn(idx):
  4627. *prefix, h, w = idx
  4628. h = h + padding[0]
  4629. w = w + padding[1]
  4630. phstart = ops.index_expr(
  4631. FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
  4632. )
  4633. pwstart = ops.index_expr(
  4634. FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
  4635. )
  4636. phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
  4637. pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)
  4638. phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
  4639. pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
  4640. phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
  4641. pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))
  4642. gradient = None
  4643. for ph_ in range(h_window_size):
  4644. for pw_ in range(w_window_size):
  4645. ph = ops.add(phstart, ops.constant(ph_, torch.int32))
  4646. pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
  4647. if divisor_override is not None:
  4648. scale = divisor_override
  4649. elif count_include_pad or not had_padding:
  4650. scale = kernel_size[0] * kernel_size[1]
  4651. else:
  4652. scale = compute_pool_size_without_padding(ph, pw)
  4653. part = ops.truediv(
  4654. grad_loader(
  4655. [
  4656. *prefix,
  4657. ops.indirect_indexing(
  4658. ops.minimum(
  4659. ph, ops.sub(phend, ops.constant(1, torch.int32))
  4660. ),
  4661. pooled_height,
  4662. check=False,
  4663. ),
  4664. ops.indirect_indexing(
  4665. ops.minimum(
  4666. pw, ops.sub(pwend, ops.constant(1, torch.int32))
  4667. ),
  4668. pooled_width,
  4669. check=False,
  4670. ),
  4671. ]
  4672. ),
  4673. scale,
  4674. )
  4675. mask = ops.and_(
  4676. ops.lt(ph, phend),
  4677. ops.lt(pw, pwend),
  4678. )
  4679. if gradient is None:
  4680. gradient = ops.where(mask, part, ops.constant(0.0, torch.float32))
  4681. else:
  4682. gradient = ops.where(mask, ops.add(gradient, part), gradient)
  4683. assert gradient is not None
  4684. return gradient
  4685. rv = Pointwise.create(
  4686. device=grad_output.get_device(),
  4687. dtype=dtype,
  4688. inner_fn=fn,
  4689. ranges=new_size,
  4690. )
  4691. return rv
  4692. fallback_avg_pool3d_backward = fallback_handler(
  4693. aten.avg_pool3d_backward.default, add_to_fallback_set=False
  4694. )
  4695. @register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
  4696. def avg_pool3d_backward(
  4697. grad_output,
  4698. x,
  4699. kernel_size,
  4700. stride,
  4701. padding,
  4702. ceil_mode,
  4703. count_include_pad,
  4704. divisor_override=None,
  4705. ):
  4706. assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
  4707. if not stride:
  4708. stride = kernel_size
  4709. if not padding:
  4710. padding = [0, 0, 0]
  4711. assert isinstance(grad_output, TensorBox)
  4712. assert isinstance(x, TensorBox)
  4713. assert len(kernel_size) == 3
  4714. assert len(stride) == 3
  4715. assert len(padding) == 3
  4716. assert len(x.get_size()) in (4, 5)
  4717. grad_output.realize_hint()
  4718. *_batch, depth, height, width = x.get_size()
  4719. _d_out, ceil_mode_d = pooling_size(
  4720. depth, 0, kernel_size, stride, padding, ceil_mode
  4721. )
  4722. _h_out, ceil_mode_h = pooling_size(
  4723. height, 1, kernel_size, stride, padding, ceil_mode
  4724. )
  4725. _w_out, ceil_mode_w = pooling_size(
  4726. width, 2, kernel_size, stride, padding, ceil_mode
  4727. )
  4728. grad_loader = grad_output.make_loader()
  4729. had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w
  4730. *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
  4731. new_size = list(x.get_size())
  4732. dtype = x.get_dtype()
  4733. d_window_size, h_window_size, w_window_size = (
  4734. max(
  4735. max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
  4736. for d in range(kernel_size[i] * 2)
  4737. )
  4738. for i in range(3)
  4739. )
  4740. window_size = d_window_size * h_window_size * w_window_size
  4741. if window_size > 125:
  4742. # Kernel size too big. Results in hard-to-optimize Triton code.
  4743. return fallback_avg_pool3d_backward(
  4744. grad_output,
  4745. x,
  4746. kernel_size,
  4747. stride,
  4748. padding,
  4749. ceil_mode,
  4750. count_include_pad,
  4751. divisor_override,
  4752. )
  4753. def compute_pool_size_without_padding(pd, ph, pw):
  4754. stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
  4755. pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
  4756. kernel_d, kernel_h, kernel_w = (
  4757. ops.constant(k, torch.int32) for k in kernel_size
  4758. )
  4759. dstart, hstart, wstart = (
  4760. ops.sub(ops.mul(p, s), pad)
  4761. for p, s, pad in zip(
  4762. [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
  4763. )
  4764. )
  4765. dend, hend, wend = (
  4766. ops.minimum(
  4767. ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
  4768. )
  4769. for start, k, dim, pad in zip(
  4770. [dstart, hstart, wstart],
  4771. [kernel_d, kernel_h, kernel_w],
  4772. [depth, height, width],
  4773. [pad_d, pad_h, pad_w],
  4774. )
  4775. )
  4776. dstart, hstart, wstart = (
  4777. ops.maximum(start, ops.constant(0, torch.int32))
  4778. for start in [dstart, hstart, wstart]
  4779. )
  4780. dend, hend, wend = (
  4781. ops.minimum(end, ops.index_expr(dim, torch.int32))
  4782. for end, dim in zip([dend, hend, wend], [depth, height, width])
  4783. )
  4784. divide_factor = ops.mul(
  4785. ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
  4786. )
  4787. return divide_factor
  4788. def fn(idx):
  4789. *prefix, d, h, w = idx
  4790. d, h, w = (v + pad for v, pad in zip([d, h, w], padding))
  4791. pdstart, phstart, pwstart = (
  4792. ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
  4793. for v, k, s in zip([d, h, w], kernel_size, stride)
  4794. )
  4795. pdend, phend, pwend = (
  4796. ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
  4797. for v, s in zip([d, h, w], stride)
  4798. )
  4799. pdstart, phstart, pwstart = (
  4800. ops.maximum(pstart, ops.constant(0, torch.int32))
  4801. for pstart in [pdstart, phstart, pwstart]
  4802. )
  4803. pdend, phend, pwend = (
  4804. ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
  4805. for pend, pooled_dim in zip(
  4806. [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
  4807. )
  4808. )
  4809. gradient = None
  4810. # Iterate over the 3D region to accumulate gradients
  4811. for pd_ in range(d_window_size):
  4812. for ph_ in range(h_window_size):
  4813. for pw_ in range(w_window_size):
  4814. pd, ph, pw = (
  4815. ops.add(pstart, ops.constant(p_, torch.int32))
  4816. for pstart, p_ in zip(
  4817. [pdstart, phstart, pwstart], [pd_, ph_, pw_]
  4818. )
  4819. )
  4820. if divisor_override is not None:
  4821. scale = divisor_override
  4822. elif count_include_pad or not had_padding:
  4823. scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
  4824. else:
  4825. scale = compute_pool_size_without_padding(pd, ph, pw)
  4826. part = ops.truediv(
  4827. grad_loader(
  4828. [
  4829. *prefix,
  4830. ops.indirect_indexing(
  4831. ops.minimum(
  4832. pd, ops.sub(pdend, ops.constant(1, torch.int32))
  4833. ),
  4834. pooled_depth,
  4835. check=False,
  4836. ),
  4837. ops.indirect_indexing(
  4838. ops.minimum(
  4839. ph, ops.sub(phend, ops.constant(1, torch.int32))
  4840. ),
  4841. pooled_height,
  4842. check=False,
  4843. ),
  4844. ops.indirect_indexing(
  4845. ops.minimum(
  4846. pw, ops.sub(pwend, ops.constant(1, torch.int32))
  4847. ),
  4848. pooled_width,
  4849. check=False,
  4850. ),
  4851. ]
  4852. ),
  4853. scale,
  4854. )
  4855. mask = ops.and_(
  4856. ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
  4857. ops.lt(pw, pwend),
  4858. )
  4859. if gradient is None:
  4860. gradient = ops.where(
  4861. mask, part, ops.constant(0.0, torch.float32)
  4862. )
  4863. else:
  4864. gradient = ops.where(mask, ops.add(gradient, part), gradient)
  4865. assert gradient is not None
  4866. return gradient
  4867. rv = Pointwise.create(
  4868. device=grad_output.get_device(),
  4869. dtype=dtype,
  4870. inner_fn=fn,
  4871. ranges=new_size,
  4872. )
  4873. return rv
  4874. def _validate_reduction_axis(x, axis):
  4875. size = x.get_size()
  4876. if isinstance(axis, int):
  4877. axis = [axis]
  4878. elif not axis:
  4879. axis = range(len(size))
  4880. if len(size) == 0:
  4881. assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}"
  4882. return []
  4883. axis = list(axis)
  4884. for i in range(len(axis)):
  4885. if axis[i] < 0:
  4886. axis[i] += len(size) if len(size) else 1
  4887. assert 0 <= axis[i] < len(size) or (len(size) == 0 and axis[i] == 0)
  4888. assert len(OrderedSet(axis)) == len(axis), "reduction axis not unique"
  4889. return axis
  4890. def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype):
  4891. if dtype is not None:
  4892. x = to_dtype(x, dtype)
  4893. size = x.get_size()
  4894. axis = OrderedSet[int](_validate_reduction_axis(x, axis))
  4895. kept_sizes = []
  4896. kept_idx = []
  4897. reduced_sizes = []
  4898. reduced_idx = []
  4899. for i in range(len(size)):
  4900. if i in axis:
  4901. reduced_idx.append(i)
  4902. reduced_sizes.append(size[i])
  4903. else:
  4904. kept_idx.append(i)
  4905. kept_sizes.append(size[i])
  4906. def loader(index, reduction_index):
  4907. assert len(reduction_index) == len(reduced_idx)
  4908. if keepdims:
  4909. assert len(index) == len(size)
  4910. index = [index[i] for i in kept_idx]
  4911. assert len(index) == len(kept_idx)
  4912. new_index = [None] * (len(index) + len(reduction_index))
  4913. for idx, var in itertools.chain(
  4914. zip(kept_idx, index), zip(reduced_idx, reduction_index)
  4915. ):
  4916. new_index[idx] = var
  4917. return inner_loader(new_index)
  4918. if keepdims:
  4919. new_size = list(size)
  4920. for i in reduced_idx:
  4921. new_size[i] = sympy.S.One
  4922. else:
  4923. new_size = kept_sizes
  4924. inner_loader = x.make_loader()
  4925. return dict(
  4926. device=x.get_device(),
  4927. dst_dtype=override_return_dtype or x.get_dtype(),
  4928. src_dtype=x.get_dtype(),
  4929. inner_fn=loader,
  4930. ranges=new_size,
  4931. reduction_ranges=reduced_sizes,
  4932. )
  4933. def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
  4934. def inner(x, axis=None, keepdims=False, *, dtype=None):
  4935. kwargs = _make_reduction_inner(
  4936. x,
  4937. axis=axis,
  4938. keepdims=keepdims,
  4939. dtype=dtype,
  4940. override_return_dtype=override_return_dtype,
  4941. )
  4942. result = Reduction.create(reduction_type=reduction_type, input_node=x, **kwargs)
  4943. if isinstance(
  4944. result.data.data, # type: ignore[attr-defined, attr-type, union-attr]
  4945. Reduction,
  4946. ): # Only realize if reduction isn't unrolled
  4947. result.realize()
  4948. return result
  4949. return inner
  4950. def _make_scan_inner(x, *, axis, dtype):
  4951. if dtype is not None:
  4952. x = to_dtype(x, dtype)
  4953. axis = _validate_dim(x, axis)
  4954. return dict(
  4955. device=x.get_device(),
  4956. dtypes=(x.get_dtype(),),
  4957. inner_fns=(x.make_loader(),),
  4958. size=x.get_size(),
  4959. axis=axis,
  4960. )
  4961. @register_lowering(aten.mean)
  4962. def mean(x, axis=None, keepdim=False, *, dtype=None):
  4963. if dtype is not None:
  4964. x = to_dtype(x, dtype)
  4965. size = x.get_size()
  4966. axis = _validate_reduction_axis(x, axis)
  4967. # compute in higher-precision until end of mean lowering
  4968. output_dtype = x.get_dtype()
  4969. if output_dtype in (torch.float16, torch.bfloat16):
  4970. x = to_dtype(x, torch.float)
  4971. sum_result = sum_(x, axis, keepdim)
  4972. denom = sympy_product(size[i] for i in axis)
  4973. denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device())
  4974. denom = ExpandView.create(denom, list(sum_result.get_size()))
  4975. return to_dtype(div(sum_result, denom), output_dtype)
  4976. def var_mean_sum_(x, axis, correction, keepdim, return_mean):
  4977. if correction is None:
  4978. correction = 1
  4979. size = x.get_size()
  4980. axis = _validate_reduction_axis(x, axis)
  4981. x_mean = mean(x, axis, keepdim=True)
  4982. if return_mean:
  4983. x_mean.realize()
  4984. diffs = square(sub(x, x_mean))
  4985. sum_result = sum_(diffs, axis, keepdim)
  4986. denom = sympy_product(size[i] for i in axis)
  4987. if correction:
  4988. denom = sympy.Max(denom - correction, 0)
  4989. denom = ir.IndexingConstant(index=denom, dtype=x.get_dtype(), device=x.get_device())
  4990. denom = ExpandView.create(denom, list(sum_result.get_size()))
  4991. x_var = div(sum_result, denom)
  4992. if not return_mean:
  4993. return (x_var,)
  4994. x_mean = x_mean if keepdim else squeeze(x_mean, axis)
  4995. return x_var, x_mean
  4996. def use_two_step_variance(x, axis, keepdim):
  4997. # Instead of unrolling welford, just unroll the simpler two-step var
  4998. axis = _validate_reduction_axis(x, axis)
  4999. kwargs = _make_reduction_inner(
  5000. x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
  5001. )
  5002. ranges = kwargs["ranges"]
  5003. reduction_numel = sympy_product(kwargs["reduction_ranges"])
  5004. return (
  5005. isinstance(reduction_numel, sympy.Integer)
  5006. and int(reduction_numel) < config.unroll_reductions_threshold
  5007. and sympy_product(ranges) != 1
  5008. )
  5009. def var_mean_welford_(x, axis, *, correction, keepdim, return_mean):
  5010. if correction is None:
  5011. correction = 1
  5012. kwargs = _make_reduction_inner(
  5013. x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
  5014. )
  5015. loader = kwargs.pop("inner_fn")
  5016. kwargs.pop("dst_dtype")
  5017. kwargs.pop("src_dtype")
  5018. mean, m2, _ = ir.WelfordReduction.create(
  5019. inner_fns=(loader,),
  5020. reduction_type="welford_reduce",
  5021. dtype=x.get_dtype(),
  5022. **kwargs,
  5023. )
  5024. m2.realize()
  5025. dtype = x.get_dtype()
  5026. size = x.get_size()
  5027. axis = _validate_reduction_axis(x, axis)
  5028. rnumel = sympy_product(size[i] for i in axis)
  5029. def get_constant_or_index_expr(x, dtype):
  5030. if isinstance(x, sympy.Expr) and not x.is_number:
  5031. return ops.to_dtype(ops.index_expr(x, torch.int64), dtype)
  5032. return ops.constant(x, dtype)
  5033. def scale_fn(data):
  5034. c = get_constant_or_index_expr(correction, dtype)
  5035. N = get_constant_or_index_expr(rnumel, dtype)
  5036. zero = ops.constant(0, dtype)
  5037. return data / ops.maximum(zero, N - c)
  5038. var = make_pointwise(scale_fn)(m2)
  5039. if return_mean:
  5040. mean.realize()
  5041. return var, mean
  5042. return (var,)
  5043. def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
  5044. out_dtype = x.get_dtype()
  5045. compute_dtype = get_computation_dtype(out_dtype)
  5046. x = to_dtype(x, compute_dtype, copy=False)
  5047. kwargs = dict(
  5048. x=x,
  5049. axis=axis,
  5050. correction=correction,
  5051. keepdim=keepdim,
  5052. return_mean=return_mean,
  5053. )
  5054. output = (
  5055. var_mean_sum_(**kwargs)
  5056. if use_two_step_variance(x, axis=axis, keepdim=keepdim)
  5057. else var_mean_welford_(**kwargs)
  5058. )
  5059. output = tuple(to_dtype(x, out_dtype, copy=False) for x in output)
  5060. return output[0] if not return_mean else output
  5061. @register_lowering([aten.var, prims.var])
  5062. def var_(x, axis=None, *, correction=None, keepdim=False):
  5063. return var_mean_helper_(
  5064. x, axis=axis, correction=correction, keepdim=keepdim, return_mean=False
  5065. )
  5066. @register_lowering(aten.var_mean)
  5067. def var_mean(x, axis=None, *, correction=None, keepdim=False):
  5068. return var_mean_helper_(
  5069. x, axis=axis, correction=correction, keepdim=keepdim, return_mean=True
  5070. )
  5071. def pow_recursive(x, y, dtype):
  5072. if y < 0:
  5073. return pow_recursive(ops.reciprocal(x), -y, dtype)
  5074. if y == 0:
  5075. return ops.constant(1, dtype)
  5076. if y == 1:
  5077. return x
  5078. result = pow_recursive(x, y // 2, dtype)
  5079. result = ops.mul(result, result)
  5080. if (y % 2) == 1:
  5081. result = ops.mul(result, x)
  5082. return result
  5083. @make_pointwise
  5084. def pow_native(a, b):
  5085. return ops.pow(a, b)
# ATen fallbacks used by ``pow`` for integer dtypes, where ``ops.pow`` is not
# used; each is created with add_to_fallback_set=False.
fallback_pow_tensor_tensor = fallback_handler(
    aten.pow.Tensor_Tensor, add_to_fallback_set=False
)
fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False)
fallback_pow_tensor_scalar = fallback_handler(
    aten.pow.Tensor_Scalar, add_to_fallback_set=False
)
  5093. @register_lowering(aten.pow, broadcast=True)
  5094. def pow(a, b):
  5095. if isinstance(b, float) and b == int(b):
  5096. return pow(a, int(b))
  5097. elif isinstance(b, float) and b == 0.5:
  5098. return sqrt(a)
  5099. elif isinstance(b, int) and b == 1:
  5100. return clone(a)
  5101. # Type promotion ensures all tensor arguments have the same type
  5102. dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox))
  5103. is_integer_pow = is_integer_dtype(dtype)
  5104. # Optimize away small fixed powers, or for integers avoid falling back to ATen
  5105. embed_exponent = isinstance(b, int) and (
  5106. -32 < b < 32 or (is_integer_pow and b >= 0)
  5107. )
  5108. if embed_exponent:
  5109. loader = a.make_loader()
  5110. def fn(idx):
  5111. return pow_recursive(loader(idx), b, a.get_dtype())
  5112. return Pointwise.create(
  5113. device=a.get_device(),
  5114. dtype=a.get_dtype(),
  5115. inner_fn=fn,
  5116. ranges=a.get_size(),
  5117. )
  5118. if isinstance(a, Number):
  5119. if a == 1:
  5120. return full_like(b, 1)
  5121. if a == 2 and is_float_dtype(b.get_dtype()):
  5122. return exp2(b)
  5123. if is_integer_pow:
  5124. # ops.pow doesn't work for integers
  5125. if isinstance(a, Number):
  5126. return fallback_pow_scalar(a, b)
  5127. elif isinstance(b, Number):
  5128. return fallback_pow_tensor_scalar(a, b)
  5129. else:
  5130. return fallback_pow_tensor_tensor(a, b)
  5131. return pow_native(a, b)
def mutate_to(changed, val, unsafe_alias=False):
    """Make ``changed`` hold the value of ``val`` and return ``changed``.

    ``val`` is unwrapped from a TensorBox.  If it is not then a StorageBox
    (e.g. a view), it is first copied into a new Pointwise node with the
    device, dtype and size of ``changed``.

    Fast path: when ``changed`` wraps a StorageBox that is not an input
    buffer, not a module buffer and not backed by ``ir.NopKernel``, ``val``
    is realized and ``changed``'s data pointer is swung to it.  Otherwise
    ``ir.MutationLayoutSHOULDREMOVE.realize_into`` is called with ``val``
    and ``changed``'s data; ``unsafe_alias`` is forwarded only to that call.
    """
    if isinstance(changed, TensorBox):
        changed_data = changed.data
    else:
        changed_data = changed
    if isinstance(val, TensorBox):
        val = val.data
    if not isinstance(val, ir.StorageBox):
        # introduce a copy to handle views
        node = Pointwise.create(
            device=changed.get_device(),
            dtype=changed.get_dtype(),
            inner_fn=val.make_loader(),
            ranges=changed.get_size(),
        )
        assert isinstance(node, (BaseView, MutableBox))
        val = node.data
    assert isinstance(val, ir.StorageBox)
    if isinstance(changed_data, ir.StorageBox) and not (
        changed_data.is_input_buffer()
        # In AOTI, module parameters and buffers are not lifted as graph inputs
        or changed_data.is_module_buffer()
        or isinstance(changed_data.data, ir.NopKernel)
    ):
        # Fast path, just swing the data pointer
        val.realize()
        changed_data.data = val.data
        return changed
    # Slow path: hand val and the target storage to realize_into (by its
    # name, presumably writing val into changed_data's buffer -- confirm).
    ir.MutationLayoutSHOULDREMOVE.realize_into(
        val, changed_data, unsafe_alias=unsafe_alias
    )
    return changed
  5164. @register_lowering(aten.fill_)
  5165. def fill_(x, fill_value):
  5166. return mutate_to(x, full_like(x, fill_value))
  5167. @register_lowering(aten.copy_, type_promotion_kind=None)
  5168. def copy_(dst, src, non_blocking=False):
  5169. if dst is src:
  5170. # dst.copy_(dst) can happen from the reinplacing pass
  5171. return dst
  5172. src = to_device(src, dst.get_device())
  5173. src = to_dtype(src, dst.get_dtype())
  5174. src = expand(src, dst.get_size())
  5175. return mutate_to(dst, src)
  5176. @make_pointwise
  5177. def floordiv(a, b):
  5178. return ops.floordiv(a, b)
  5179. @make_pointwise
  5180. def truncdiv(a, b):
  5181. return ops.truncdiv(a, b)
  5182. @register_lowering(aten.div, broadcast=True)
  5183. def div_mode(a, b, rounding_mode=None):
  5184. both_integer = is_integer_type(a) and is_integer_type(b)
  5185. both_boolean = is_boolean_type(a) and is_boolean_type(b)
  5186. # floordiv and truncdiv need special handling for integer tensors on Triton,
  5187. # see the discussion at https://github.com/triton-lang/triton/issues/605
  5188. if rounding_mode == "floor":
  5189. assert not both_boolean, "floordiv operands can not be boolean at the same time"
  5190. return floordiv(a, b) if both_integer else floor(div(a, b))
  5191. if rounding_mode == "trunc":
  5192. assert not both_boolean, "truncdiv operands can not be boolean at the same time"
  5193. return truncdiv(a, b) if both_integer else trunc(div(a, b))
  5194. return div(a, b)
  5195. @register_lowering([aten.mul], broadcast=True)
  5196. def mul(a, b):
  5197. both_bool = is_boolean_type(a) and is_boolean_type(b)
  5198. if both_bool:
  5199. return logical_and(a, b)
  5200. else:
  5201. fn = ops_wrapper(aten.mul.__name__)
  5202. return make_pointwise(fn)(a, b)
  5203. def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
  5204. """Try convert an arbitrary IR node into an ir.Constant value"""
  5205. # First try unwrapping the IRNode to see if it is already an ir.Constant
  5206. # Optional step, but avoids unnecessary inner_fn evaluation.
  5207. if isinstance(x, ir.MutableBox):
  5208. return get_constant_value(x.data)
  5209. if isinstance(x, ir.BaseView):
  5210. return get_constant_value(x.unwrap_view())
  5211. if isinstance(x, ir.Constant):
  5212. return x
  5213. # If the unwrapped node is not an ir.Constant, try evaluating inner_fn
  5214. # to see if the returned value is from an `ops.constant` call
  5215. if not isinstance(x, ir.Loops):
  5216. return None
  5217. handler = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
  5218. with (
  5219. V.set_ops_handler(handler),
  5220. patch.object(ir.FlexibleLayout, "allow_indexing", True),
  5221. ):
  5222. out = x.inner_fn(*x.inner_fn_args())
  5223. assert isinstance(out, torch._inductor.virtualized.OpsValue)
  5224. if isinstance(out.value, ir.Constant):
  5225. return out.value
  5226. return None
  5227. # NOTE: prims.div maps to a / b in C, so performs truncation division on
  5228. # integer inputs and true division for floating and complex inputs.
  5229. @register_lowering([prims.div], broadcast=True)
  5230. def div_prim(a, b):
  5231. is_integral = all(is_boolean_type(x) or is_integer_type(x) for x in [a, b])
  5232. if is_integral:
  5233. return truncdiv(a, b)
  5234. # Disable CPU optimization to avoid precision issues.
  5235. # see https://github.com/pytorch/pytorch/issues/157959
  5236. if (divisor := get_constant_value(b)) is not None and a.get_device().type != "cpu":
  5237. # Replace divide by constant with multiply by reciprocal
  5238. if divisor.value == 0:
  5239. reciprocal = math.copysign(float("inf"), divisor.value)
  5240. else:
  5241. reciprocal = 1.0 / divisor.value
  5242. return mul(a, reciprocal)
  5243. def fn(*args):
  5244. return ops.truediv(*args)
  5245. return make_pointwise(fn)(a, b)
  5246. @register_lowering(
  5247. [aten.true_divide, aten.div.Tensor],
  5248. broadcast=True,
  5249. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  5250. )
  5251. def div(a, b):
  5252. a, b = promote_constants(
  5253. (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
  5254. )
  5255. return div_prim(a, b)
  5256. @register_lowering([aten.fmod, prims.fmod], broadcast=True)
  5257. def fmod(a, b):
  5258. is_integral = is_boolean_type(a) or is_integer_type(a)
  5259. if is_integral:
  5260. def fn(a, b):
  5261. return ops.mod(a, b)
  5262. else:
  5263. def fn(a, b):
  5264. return ops.fmod(a, b)
  5265. return make_pointwise(fn)(a, b)
  5266. @register_lowering([aten.sum, prims.sum])
  5267. def sum_(x, axis=None, keepdims=False, *, dtype=None):
  5268. if (
  5269. is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  5270. ) and dtype is None:
  5271. dtype = torch.int64
  5272. fn = make_reduction("sum", override_return_dtype=dtype)
  5273. return fn(x, axis, keepdims, dtype=dtype)
# ATen fallbacks for the scan lowerings below, used when ir.Scan.create
# cannot build a scan (returns None).
fallback_cumsum = fallback_handler(aten.cumsum.default)
fallback_cumprod = fallback_handler(aten.cumprod.default)
fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default)
fallback_cummax = fallback_handler(aten.cummax.default)
fallback_cummin = fallback_handler(aten.cummin.default)
  5279. @register_lowering(aten.cumsum)
  5280. def cumsum(x, axis=None, dtype=None):
  5281. if (
  5282. is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  5283. ) and dtype is None:
  5284. dtype = torch.int64
  5285. if len(x.get_size()) == 0:
  5286. assert axis in [0, -1]
  5287. dtype = dtype or x.get_dtype()
  5288. return to_dtype(x, dtype, copy=True)
  5289. def combine_fn(a_tuple, b_tuple):
  5290. (a,) = a_tuple
  5291. (b,) = b_tuple
  5292. return (ops.add(a, b),)
  5293. kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
  5294. (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn)
  5295. if result is None:
  5296. return fallback_cumsum(x, dim=axis, dtype=dtype)
  5297. return result
  5298. @register_lowering(aten.cumprod)
  5299. def cumprod(x, axis=None, dtype=None):
  5300. if (
  5301. is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  5302. ) and dtype is None:
  5303. dtype = torch.int64
  5304. if len(x.get_size()) == 0:
  5305. assert axis in [0, -1]
  5306. dtype = dtype or x.get_dtype()
  5307. return to_dtype(x, dtype, copy=True)
  5308. def combine_fn(a_tuple, b_tuple):
  5309. (a,) = a_tuple
  5310. (b,) = b_tuple
  5311. return (ops.mul(a, b),)
  5312. kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
  5313. (result,) = ir.Scan.create(**kwargs, combine_fn=combine_fn)
  5314. if result is None:
  5315. return fallback_cumprod(x, dim=axis, dtype=dtype)
  5316. return result
  5317. @register_lowering(aten.logcumsumexp)
  5318. def logcumsumexp(x, dim):
  5319. def log_add_exp_helper(a_tuple, b_tuple):
  5320. (a,) = a_tuple
  5321. (b,) = b_tuple
  5322. min_v = ops.minimum(a, b)
  5323. max_v = ops.maximum(a, b)
  5324. mask = (min_v != max_v) | (~ops.isinf(min_v))
  5325. return (ops.where(mask, ops.log1p(ops.exp(min_v - max_v)) + max_v, a),)
  5326. dtype = x.get_dtype()
  5327. if len(x.get_size()) == 0:
  5328. assert dim in [0, -1]
  5329. return clone(x)
  5330. kwargs = _make_scan_inner(x, axis=dim, dtype=dtype)
  5331. (result,) = ir.Scan.create(**kwargs, combine_fn=log_add_exp_helper)
  5332. if result is None:
  5333. return fallback_logcumsumexp(x, dim=dim)
  5334. return result
  5335. @register_lowering(aten.cummax, type_promotion_kind=None)
  5336. def cummax(x, axis=None):
  5337. if len(x.get_size()) == 0:
  5338. assert axis in [0, -1]
  5339. return clone(x), empty_like(x, dtype=torch.int64)
  5340. dtype = x.get_dtype()
  5341. combine_fn = ir.get_reduction_combine_fn(
  5342. "argmax", dtype=dtype, arg_break_ties_left=False
  5343. )
  5344. kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
  5345. kwargs["dtypes"] = (dtype, torch.int64)
  5346. kwargs["inner_fns"] = (
  5347. x.make_loader(),
  5348. lambda idx: ops.index_expr(idx[axis], torch.int64),
  5349. )
  5350. values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type]
  5351. if values is None:
  5352. return fallback_cummax(x, dim=axis)
  5353. return values, indices
  5354. @register_lowering(aten.cummin, type_promotion_kind=None)
  5355. def cummin(x, axis=None):
  5356. if len(x.get_size()) == 0:
  5357. assert axis in [0, -1]
  5358. return clone(x), empty_like(x, dtype=torch.int64)
  5359. dtype = x.get_dtype()
  5360. combine_fn = ir.get_reduction_combine_fn(
  5361. "argmin", dtype=dtype, arg_break_ties_left=False
  5362. )
  5363. kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
  5364. kwargs["dtypes"] = (dtype, torch.int64)
  5365. kwargs["inner_fns"] = (
  5366. x.make_loader(),
  5367. lambda idx: ops.index_expr(idx[axis], torch.int64),
  5368. )
  5369. values, indices = ir.Scan.create(**kwargs, combine_fn=combine_fn) # type: ignore[arg-type]
  5370. if values is None:
  5371. return fallback_cummin(x, dim=axis)
  5372. return values, indices
  5373. @register_lowering(aten.prod)
  5374. def prod(x, axis=None, keepdims=False, *, dtype=None):
  5375. if (
  5376. is_integer_dtype(x.get_dtype()) or is_boolean_dtype(x.get_dtype())
  5377. ) and dtype is None:
  5378. dtype = torch.int64
  5379. fn = make_reduction("prod", override_return_dtype=dtype)
  5380. return fn(x, axis, keepdims, dtype=dtype)
  5381. @register_lowering(aten.any)
  5382. def reduce_any(x, dim=None, keepdim=False):
  5383. x = to_dtype(x, torch.bool)
  5384. return make_reduction("any")(x, axis=dim, keepdims=keepdim)
  5385. @register_lowering(aten.max, type_promotion_kind=None)
  5386. def reduce_max(x, dim=None, keepdim=False):
  5387. if dim is not None:
  5388. return (
  5389. reduce_amax(x, axis=dim, keepdims=keepdim),
  5390. reduce_argmax(x, axis=dim, keepdims=keepdim),
  5391. )
  5392. return reduce_amax(x, axis=None, keepdims=keepdim)
  5393. @register_lowering(aten.min, type_promotion_kind=None)
  5394. def reduce_min(x, dim=None, keepdim=False):
  5395. if dim is not None:
  5396. return (
  5397. reduce_amin(x, axis=dim, keepdims=keepdim),
  5398. reduce_argmin(x, axis=dim, keepdims=keepdim),
  5399. )
  5400. return reduce_amin(x, axis=None, keepdims=keepdim)
# Plain reductions built with make_reduction; argmax/argmin always return int64.
register_lowering(prims.xor_sum)(make_reduction("xor_sum"))
reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
reduce_argmax = register_lowering(aten.argmax)(
    make_reduction("argmax", override_return_dtype=torch.int64)
)
reduce_argmin = register_lowering(aten.argmin)(
    make_reduction("argmin", override_return_dtype=torch.int64)
)
# add on two bool inputs is lowered as logical_or.
add = register_pointwise(
    aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or"
)
# ATen fallback for sort_stable when the sort can't be generated.
sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False)
  5414. @register_lowering(aten.sort.stable, type_promotion_kind=None)
  5415. def sort_stable(x, *, stable=None, dim=-1, descending=False):
  5416. if stable is None:
  5417. stable = False
  5418. shape = x.get_size()
  5419. device = x.get_device()
  5420. dim = canonicalize_dim(len(shape), dim)
  5421. if len(shape) == 0:
  5422. return clone(x), _full(0, device, torch.int64, shape)
  5423. dim_size = shape[dim] if len(shape) else 1
  5424. if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max):
  5425. return sort_fallback(x, stable=stable, dim=dim, descending=descending)
  5426. indices = iota(
  5427. dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False
  5428. )
  5429. view_shape = [1] * len(shape)
  5430. if len(shape):
  5431. view_shape[dim] = dim_size
  5432. indices = view(indices, view_shape)
  5433. indices = expand(indices, shape)
  5434. values, indices = ir.Sort.create(
  5435. device=device,
  5436. dtypes=(x.dtype, indices.dtype),
  5437. inner_fns=(x.make_loader(), indices.make_loader()),
  5438. size=shape,
  5439. axis=dim,
  5440. stable=stable,
  5441. descending=descending,
  5442. )
  5443. if values is None:
  5444. return sort_fallback(x, stable=stable, dim=dim, descending=descending)
  5445. assert indices is not None
  5446. return values, to_dtype(indices, torch.int64)
  5447. @register_lowering(aten.sort.default, type_promotion_kind=None)
  5448. def sort(x, dim=-1, descending=False):
  5449. return sort_stable(x, stable=False, dim=dim, descending=descending)
  5450. def register_pointwise_numeric(op, name=None, triton_fallback=None):
  5451. return register_pointwise(
  5452. op,
  5453. name=name,
  5454. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  5455. triton_fallback=triton_fallback,
  5456. )
  5457. def register_pointwise_numeric_ldf64(op: torch._ops.OpOverloadPacket):
  5458. register_op_requires_libdevice_fp64(op.__name__)
  5459. return register_pointwise(
  5460. op,
  5461. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  5462. )
  5463. rsqrt = register_pointwise_numeric(aten.rsqrt)
  5464. exp = register_pointwise_numeric_ldf64(aten.exp)
  5465. exp2 = register_pointwise_numeric(aten.exp2)
  5466. expm1 = register_pointwise_numeric(aten.expm1)
  5467. relu = register_pointwise(aten.relu)
  5468. sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
  5469. sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
  5470. square = register_pointwise(aten.square)
  5471. sub = register_pointwise(aten.sub, allow_alpha=True)
  5472. register_pointwise_numeric_ldf64(aten.cos)
  5473. register_pointwise_numeric_ldf64(aten.sin)
  5474. abs = register_pointwise(aten.abs)
  5475. bitwise_and = register_pointwise(aten.bitwise_and)
  5476. bitwise_left_shift = register_pointwise(aten.bitwise_left_shift)
  5477. bitwise_not = register_pointwise(
  5478. aten.bitwise_not, override_fn_when_input_bool="logical_not"
  5479. )
  5480. bitwise_or = register_pointwise(aten.bitwise_or)
  5481. bitwise_right_shift = register_pointwise(aten.bitwise_right_shift)
  5482. bitwise_xor = register_pointwise(aten.bitwise_xor)
  5483. register_pointwise_numeric(aten.lgamma)
  5484. erf = register_pointwise_numeric(aten.erf)
  5485. register_lowering(
  5486. aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
  5487. )(erf)
  5488. register_pointwise_numeric(aten.log1p)
  5489. register_pointwise_numeric(aten.tan)
  5490. register_pointwise_numeric(aten.tanh)
  5491. register_pointwise_numeric_ldf64(aten.log)
  5492. logical_and = register_pointwise(
  5493. aten.logical_and,
  5494. type_promotion_kind=None,
  5495. convert_input_to_bool=True,
  5496. override_return_dtype=torch.bool,
  5497. )
  5498. logical_not = register_pointwise(
  5499. aten.logical_not,
  5500. type_promotion_kind=None,
  5501. convert_input_to_bool=True,
  5502. override_return_dtype=torch.bool,
  5503. )
  5504. logical_or = register_pointwise(
  5505. aten.logical_or,
  5506. type_promotion_kind=None,
  5507. convert_input_to_bool=True,
  5508. override_return_dtype=torch.bool,
  5509. )
  5510. logical_xor = register_pointwise(
  5511. aten.logical_xor,
  5512. type_promotion_kind=None,
  5513. convert_input_to_bool=True,
  5514. override_return_dtype=torch.bool,
  5515. )
  5516. maximum = register_pointwise(aten.maximum)
  5517. minimum = register_pointwise(aten.minimum)
  5518. register_lowering(aten.clamp_min)(maximum)
  5519. register_lowering(aten.clamp_max)(minimum)
  5520. neg = register_pointwise(aten.neg)
  5521. abs = register_pointwise(aten.abs)
  5522. reciprocal = register_pointwise_numeric(aten.reciprocal)
  5523. register_pointwise(aten.remainder)
  5524. sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity")
  5525. register_pointwise(aten.ceil)
  5526. register_pointwise(aten.signbit, override_return_dtype=torch.bool)
  5527. register_lowering(aten._neg_view)(neg)
  5528. register_pointwise(aten.le, override_return_dtype=torch.bool)
  5529. register_pointwise(aten.lt, override_return_dtype=torch.bool)
  5530. register_pointwise(aten.ge, override_return_dtype=torch.bool)
  5531. gt = register_pointwise(aten.gt, override_return_dtype=torch.bool)
  5532. register_pointwise(aten.eq, override_return_dtype=torch.bool)
  5533. register_pointwise(aten.ne, override_return_dtype=torch.bool)
  5534. register_pointwise_numeric(aten.cosh)
  5535. register_pointwise_numeric(aten.sinh)
  5536. register_pointwise_numeric(aten.acos)
  5537. register_pointwise_numeric(aten.acosh)
  5538. register_pointwise_numeric(aten.asin)
  5539. register_pointwise_numeric(aten.asinh)
  5540. register_pointwise_numeric(aten.atan2)
  5541. register_pointwise_numeric(aten.atan)
  5542. register_pointwise_numeric(aten.atanh)
  5543. register_pointwise_numeric(aten.copysign)
  5544. register_pointwise_numeric(aten.erfc)
  5545. register_pointwise_numeric(aten.erfinv)
  5546. register_pointwise_numeric(aten.hypot)
  5547. register_pointwise_numeric(aten.log10)
  5548. register_pointwise_numeric(aten.log2)
  5549. register_pointwise_numeric(aten.nextafter)
  5550. from .codegen.common import BackendFeature, pointwise_overrides_data
  5551. def _get_pointwise_overrides(ns, name):
  5552. data = pointwise_overrides_data[name]
  5553. op = getattr(ns, data.name, None)
  5554. if op is None:
  5555. return
  5556. def make_triton_fallback(op):
  5557. if data.triton is None:
  5558. return fallback_handler(op)
  5559. if isinstance(op, torch._ops.OpOverloadPacket):
  5560. for olname in op.overloads():
  5561. ol = getattr(op, olname)
  5562. yield ol, data.type_promotion_kind, make_triton_fallback(ol)
  5563. else:
  5564. yield op, data.type_promotion_kind, make_triton_fallback(op)
# Register every entry of pointwise_overrides_data: for each override name,
# first the matching aten op(s), then the matching prims op(s).
for name in pointwise_overrides_data:
    for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
        aten, name
    ):
        register_pointwise(
            op,
            name=name,
            type_promotion_kind=type_promotion_kind,
            triton_fallback=triton_fallback,
        )
    for op, type_promotion_kind, triton_fallback in _get_pointwise_overrides(
        prims, name
    ):
        register_pointwise(
            op,
            name=name,
            type_promotion_kind=type_promotion_kind,
            triton_fallback=triton_fallback,
        )
# Foreach (list-of-tensors) lowerings, each paired with the single-tensor
# lowering passed as the second argument (presumably applied per element).
# The results bound to foreach_* names are later passed to
# register_foreach_inplace.
foreach_add_list = register_foreach_pointwise(
    aten._foreach_add.List, add, allow_alpha=True
)
foreach_add_scalar = register_foreach_pointwise(
    aten._foreach_add.Scalar, add, allow_alpha=True
)
register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)
foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul)
register_foreach_pointwise(aten._foreach_mul.Tensor, mul)
foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)
register_foreach_pointwise(aten._foreach_sub.List, sub)
register_foreach_pointwise(aten._foreach_sub.Scalar, sub)
register_foreach_pointwise(aten._foreach_neg.default, neg)
register_foreach_pointwise(aten._foreach_abs.default, abs)
register_foreach_pointwise(aten._foreach_pow.Scalar, pow)
register_foreach_pointwise(aten._foreach_pow.List, pow)
register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow)
foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
register_foreach_pointwise(aten._foreach_div.Tensor, div)
foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)
register_foreach_pointwise(aten._foreach_sqrt, sqrt)
register_foreach_pointwise(aten._foreach_rsqrt, rsqrt)
register_foreach_pointwise(aten._foreach_maximum.List, maximum)
register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum)
register_foreach_pointwise(aten._foreach_minimum.List, minimum)
register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum)
# clamp_min/clamp_max map onto maximum/minimum, as for the single-tensor ops.
register_foreach_pointwise(aten._foreach_clamp_min.List, maximum)
register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum)
register_foreach_pointwise(aten._foreach_clamp_max.List, minimum)
register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)
register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
register_foreach_pointwise(aten._foreach_sign, sign)
foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy)
  5617. # these are only encountered as outputs of the graph
  5618. # reinplacing epilogue copies improves compile time
  5619. # by removing extra buffers sent to the scheduler.
  5620. def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op):
  5621. inplaceable_foreach_ops[outplace_aten_op] = aten_op
  5622. inplace_foreach_ops.add(aten_op)
  5623. def fn(*args, **kwargs):
  5624. results = outplace_op(*args, **kwargs)
  5625. mut_results = []
  5626. for arg, result in zip(args[0], results):
  5627. mut_results.append(mutate_to(arg, result, unsafe_alias=True))
  5628. return mut_results
  5629. _register_foreach_lowering(aten_op, fn)
# In-place foreach variants: (in-place aten op, its out-of-place aten op,
# the out-of-place foreach lowering registered earlier).
register_foreach_inplace(
    aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list
)
register_foreach_inplace(
    aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar
)
register_foreach_inplace(
    aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list
)
register_foreach_inplace(
    aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar
)
register_foreach_inplace(
    aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list
)
register_foreach_inplace(
    aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
)
register_foreach_inplace(
    aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy
)
  5651. def register_inplace(aten_op, outplace_op):
  5652. @register_lowering(aten_op, type_promotion_kind=None)
  5653. def fn(*args, **kwargs):
  5654. result = outplace_op(*args, **kwargs)
  5655. result = to_dtype(result, args[0].get_dtype())
  5656. return mutate_to(args[0], result)
  5657. return fn
# In-place aten ops lowered as "compute out-of-place, then mutate arg 0".
register_inplace(aten.add_, add)
register_inplace(aten.bitwise_and_, bitwise_and)
register_inplace(aten.bitwise_left_shift_, bitwise_left_shift)
register_inplace(aten.bitwise_not_, bitwise_not)
register_inplace(aten.bitwise_or_, bitwise_or)
register_inplace(aten.bitwise_right_shift_, bitwise_right_shift)
register_inplace(aten.bitwise_xor_, bitwise_xor)
register_inplace(aten.mul_, mul)
register_inplace(aten.div_.Tensor, div)
register_inplace(aten.div_.Tensor_mode, div_mode)
register_inplace(aten.logical_and_, logical_and)
register_inplace(aten.logical_not_, logical_not)
register_inplace(aten.logical_or_, logical_or)
register_inplace(aten.logical_xor_, logical_xor)
register_inplace(aten.sub_, sub)
register_inplace(aten.relu_, relu)
register_inplace(aten.sigmoid_, sigmoid)
# Python operator dunders reuse the bitwise lowerings.
register_lowering(aten.__and__)(bitwise_and)
register_lowering(aten.__lshift__)(bitwise_left_shift)
register_lowering(aten.__or__)(bitwise_or)
register_lowering(aten.__rshift__)(bitwise_right_shift)
register_lowering(aten.__xor__)(bitwise_xor)
# In-place dunders use the out-of-place aten dunder op as ``outplace_op``.
register_inplace(aten.__iand__, aten.__and__)
register_inplace(aten.__ilshift__, aten.__lshift__)
register_inplace(aten.__ior__, aten.__or__)
register_inplace(aten.__irshift__, aten.__rshift__)
register_inplace(aten.__ixor__, aten.__xor__)
  5685. @register_lowering(aten.sym_constrain_range)
  5686. def sym_constrain_range(a, min=None, max=None):
  5687. return None
  5688. @register_lowering(aten.sym_size.int)
  5689. def sym_size(a, dim):
  5690. val = V.graph.current_node.meta["val"]
  5691. # Note [Can val be an int?]
  5692. # ~~~~~~~~~~~~~~~~~~~~~~~~~
  5693. # In principle, someone could construct an FX graph where
  5694. # a call to size/stride has a val that is a plain int (not
  5695. # SymInt). However, we will maintain the invariant that
  5696. # this is not possible: if you are constructing an FX graph
  5697. # where there is a call to size/stride that returns an
  5698. # int, but you KNOW that int must always be a constant,
  5699. # then you do not need trace that call at all (and just
  5700. # constant propagate the integer as is.)
  5701. assert isinstance(val, torch.SymInt), (
  5702. f"Expect val to be torch.SymInt but got val={val}"
  5703. )
  5704. return val.node.expr
  5705. @register_lowering(aten.sym_stride.int)
  5706. def sym_stride(a, dim):
  5707. val = V.graph.current_node.meta["val"]
  5708. # See Note [Can val be an int?]
  5709. assert isinstance(val, torch.SymInt), (
  5710. f"Expect val to be torch.SymInt but got val={val}"
  5711. )
  5712. return val.node.expr
  5713. @register_lowering(aten.sym_numel)
  5714. def sym_numel(a):
  5715. return a.get_numel()
# Register each entry of ``magic_methods`` as a lowering: the key is mapped to
# an operator via ``method_to_operator`` and the value is used as its lowering.
# NOTE(review): ``method`` and ``func`` stay bound at module scope after this
# loop; renaming them would change the module namespace.
for method, func in magic_methods.items():
    register_lowering(method_to_operator(method))(func)  # type: ignore[arg-type]
  5718. @register_lowering(torch.sym_sum)
  5719. def sym_sum(args):
  5720. return sympy.Add(*args)
  5721. @register_lowering(aten._foobar)
  5722. def foobar(self, *args, **kwargs):
  5723. raise NotImplementedError("Helpful for debugging")
  5724. @register_lowering(torch.ops._inductor_test.realize)
  5725. def _realize(x):
  5726. x.realize()
  5727. return clone(x)
  5728. @register_lowering(torch.ops.inductor.resize_storage_bytes_)
  5729. def resize_storage_bytes_(variable, new_size):
  5730. variable.realize()
  5731. ir.ResizeStorageBytes(variable, new_size)
  5732. return variable
  5733. @register_lowering(torch.ops.aten.set_.source_Tensor)
  5734. def set__source_tensor(self, source_tensor):
  5735. self.realize()
  5736. source_tensor.realize()
  5737. return TensorBox.create(ir.SetSourceTensorKernel(self, source_tensor))
  5738. if hasattr(torch.ops.fsdp, "copy_"):
  5739. @register_lowering(torch.ops.fsdp.copy_.default)
  5740. def fsdp_copy_(dst, src):
  5741. if dst is src:
  5742. # dst.copy_(dst) can happen from the reinplacing pass
  5743. return dst
  5744. src = to_device(src, dst.get_device())
  5745. src = to_dtype(src, dst.get_dtype())
  5746. src = expand(src, dst.get_size())
  5747. return mutate_to(dst, src)
  5748. @register_lowering(torch.ops.aten.resize)
  5749. def resize(x, size, *, memory_format=None):
  5750. assert isinstance(x, TensorBox)
  5751. assert isinstance(size, (list, tuple))
  5752. if memory_format is None:
  5753. memory_format = torch.contiguous_format
  5754. if memory_format == torch.preserve_format:
  5755. raise RuntimeError(f"unsupported memory format: {memory_format}")
  5756. if memory_format == torch.channels_last:
  5757. assert len(size) == 4
  5758. if memory_format == torch.channels_last_3d:
  5759. assert len(size) == 5
  5760. old_numel = x.get_numel()
  5761. dtype = x.get_dtype()
  5762. device = x.get_device_or_error()
  5763. if isinstance(x.data, ir.BaseView):
  5764. x.data = x.data.unwrap_view()
  5765. if (
  5766. torch.are_deterministic_algorithms_enabled()
  5767. and torch.utils.deterministic.fill_uninitialized_memory # type: ignore[attr-defined]
  5768. ):
  5769. if is_float_dtype(dtype):
  5770. uninitialized_val = float("nan")
  5771. elif is_integer_dtype(dtype):
  5772. uninitialized_val = torch.iinfo(dtype).max
  5773. else:
  5774. uninitialized_val = True
  5775. else:
  5776. # using zero as that is what empty does
  5777. uninitialized_val = 0.0
  5778. if V.graph.sizevars.statically_known_equals(old_numel, 0): # type: ignore[arg-type]
  5779. return full(size, uninitialized_val, dtype=dtype, device=device)
  5780. x_flat = as_strided(
  5781. x,
  5782. [
  5783. old_numel,
  5784. ],
  5785. [
  5786. 1,
  5787. ],
  5788. )
  5789. flat_loader = x_flat.make_loader()
  5790. out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format)
  5791. out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer()
  5792. def inner_fn(idx):
  5793. flat_index = out_indexer(idx)
  5794. flat_index_expr = ops.index_expr(flat_index, torch.int64)
  5795. limit = ops.index_expr(old_numel, torch.int64)
  5796. mask = ops.lt(flat_index_expr, limit)
  5797. return ops.masked(mask, lambda: flat_loader([flat_index]), uninitialized_val)
  5798. out = Pointwise.create(
  5799. device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size)
  5800. )
  5801. return out
# auto_functionalized is not lowered to Inductor IR; make_fallback is
# applied to it instead.
from torch._higher_order_ops.auto_functionalize import auto_functionalized

make_fallback(auto_functionalized)
  5804. @register_lowering(triton_kernel_wrapper_mutation)
  5805. def triton_kernel_wrap_(
  5806. *,
  5807. kernel_idx,
  5808. constant_args_idx,
  5809. grid,
  5810. tma_descriptor_metadata,
  5811. kwargs,
  5812. ):
  5813. from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
  5814. constant_args = kernel_side_table.get_constant_args(constant_args_idx)
  5815. ir.UserDefinedTritonKernel(
  5816. kernel_idx=kernel_idx,
  5817. grid=grid,
  5818. tma_descriptor_metadata=tma_descriptor_metadata,
  5819. kernel_args={**kwargs, **constant_args},
  5820. )
  5821. return {key: val for key, val in kwargs.items() if isinstance(val, TensorBox)}
  5822. @register_lowering(torch.ops.higher_order.cond, type_promotion_kind=None)
  5823. def cond(pred, true_fn, false_fn, operands):
  5824. if any(isinstance(x, IRNode) and is_triton(x) for x in [pred, *operands]):
  5825. msg = "control flow operator: torch.cond."
  5826. if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
  5827. msg = f"{msg} Found from : \n {stack_trace}"
  5828. V.graph.disable_cudagraphs_reason = msg
  5829. result = ir.Conditional.create(pred, true_fn, false_fn, operands)
  5830. return list(map(TensorBox.create, result))
  5831. @register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None)
  5832. def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False):
  5833. if any(
  5834. isinstance(x, IRNode) and is_triton(x)
  5835. for x in carried_inputs + additional_inputs
  5836. ):
  5837. msg = "control flow operator: torch.while_loop."
  5838. if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
  5839. msg = f"{msg} Found from : \n {stack_trace}"
  5840. V.graph.disable_cudagraphs_reason = msg
  5841. def _map_output(out: Any):
  5842. if isinstance(out, TensorBox):
  5843. return out
  5844. elif isinstance(out, ir.StorageBox):
  5845. return TensorBox(out)
  5846. elif isinstance(out, ir.MultiOutput):
  5847. return TensorBox.create(out)
  5848. else:
  5849. raise RuntimeError(f"NYI unsupported output type: {type(out)}")
  5850. result = ir.WhileLoop.create(
  5851. cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
  5852. )
  5853. assert isinstance(result, Sequence)
  5854. return list(map(_map_output, result))
# while_loop_stack_output reuses the while_loop lowering above, with
# stack_output=True forwarded to ir.WhileLoop.create.
register_lowering(
    torch.ops.higher_order.while_loop_stack_output, type_promotion_kind=None
)(functools.partial(while_loop, stack_output=True))
  5858. @register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
  5859. def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
  5860. result = ir.InvokeSubgraph.create(subgraph_fn, *operands)
  5861. return list(map(TensorBox.create, result)) # type: ignore[call-overload]
@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None)
def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None):
    """Inline the ``invoke_quant`` subgraph into the graph being lowered.

    Every node of ``subgraph_fn.graph_module.graph`` is interpreted through
    ``V.graph``: placeholders are bound to ``operands``, output values are
    realized and their operation names recorded in
    ``V.graph.invoke_quant_ops`` (and in ``V.graph.low_precision_codegen_ops``
    when ``quant_options.codegen_low_precision`` is set), and all other
    nodes are run with ``V.graph.run_node``.

    ``scheme`` is not used by this lowering.

    Returns:
        The value of ``torch.fx.Interpreter.output`` for the output node,
        or None if the subgraph has no output node.
    """
    output = None

    # quant_options is read from the metadata of the invoke_quant FX node.
    quant_options = V.graph.current_node.meta.get("quant_options", None)
    assert quant_options is not None

    for i, node in enumerate(subgraph_fn.graph_module.graph.nodes):
        if node.op == "placeholder":
            # NOTE(review): operands are indexed by the node's position in the
            # whole graph, which assumes all placeholders come before any other
            # node — confirm.
            V.graph.env[node] = operands[i]
            continue
        # todo getattr
        elif node.op == "output":
            args, kwargs = V.graph.fetch_args_kwargs_from_env(node)

            # NOTE(review): each element of args is treated as an IR node; a
            # nested tuple of outputs would not be flattened here — confirm.
            for v in itertools.chain(args, kwargs.values()):
                # Realize the value before recording its operation name.
                v.realize()

                if quant_options.codegen_low_precision:
                    V.graph.low_precision_codegen_ops.add(v.get_operation_name())

                V.graph.invoke_quant_ops.add(v.get_operation_name())

            output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)
        else:
            V.graph.env[node] = V.graph.run_node(node)

    return output
  5883. @register_lowering(associative_scan_op, type_promotion_kind=None)
  5884. def associative_scan(
  5885. combine_fn: ir.Subgraph, xs, additional_inputs: tuple[torch.Tensor]
  5886. ):
  5887. from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph
  5888. if len(additional_inputs) > 0:
  5889. raise RuntimeError(
  5890. "Unable to generate code for associative_scan op, because there are lifted arguments"
  5891. )
  5892. subgraph_inputs = [
  5893. InputDescriptor(dtype=x.get_dtype(), device=x.get_device())
  5894. for x in itertools.chain(xs, xs)
  5895. ]
  5896. lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs) # type: ignore[var-annotated]
  5897. def wrapped_combine_fn(lhs, rhs):
  5898. return lowered_combine_fn(
  5899. *pytree.tree_leaves(lhs),
  5900. *pytree.tree_leaves(rhs),
  5901. )
  5902. kwargs = _make_scan_inner(xs[0], axis=0, dtype=None)
  5903. kwargs["dtypes"] = tuple(x.get_dtype() for x in xs)
  5904. kwargs["inner_fns"] = tuple(x.make_loader() for x in xs)
  5905. result = ir.Scan.create(
  5906. combine_fn=wrapped_combine_fn,
  5907. can_fallback_to_aten=False,
  5908. **kwargs,
  5909. )
  5910. if result[0] is None:
  5911. raise RuntimeError("Unable to generate code for associative_scan op")
  5912. return result
  5913. @register_lowering(torch.ops.prims._sink_tokens.default)
  5914. def _sink_tokens(tokens):
  5915. return None
  5916. @register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None)
  5917. def with_effects(token, op, *args, **kwargs):
  5918. result = ir.EffectfulKernel.create(op, *args, **kwargs)
  5919. from torch._higher_order_ops.effects import get_effect_key
  5920. effect_type = get_effect_key(op, args, kwargs)
  5921. assert effect_type is not None
  5922. effectful_kernel = V.graph.effectful_ops[effect_type]
  5923. if result is None:
  5924. return (effectful_kernel,)
  5925. result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result)
  5926. # See [NOTE: with_effects return type]
  5927. # Only return `result` if it is a tuple, not list.
  5928. if not isinstance(result, tuple):
  5929. return (effectful_kernel, result)
  5930. else:
  5931. return (effectful_kernel, *result)
# Register the lowerings provided by .comm_lowering.  This import sits here,
# after the lowerings above are defined, rather than at the top of the file
# (presumably to avoid a circular import — confirm).
from .comm_lowering import register_comm_lowerings

register_comm_lowerings()
  5934. @register_lowering(inductor_prims.prepare_softmax_online, type_promotion_kind=None)
  5935. def prepare_softmax_online(x, dim):
  5936. """
  5937. Lowering inductor_prims.prepare_softmax_online to compute max/sum in one pass if no split is needed.
  5938. """
  5939. kwargs = _make_reduction_inner(
  5940. x, axis=dim, keepdims=True, dtype=None, override_return_dtype=None
  5941. )
  5942. reduction_ranges = kwargs["reduction_ranges"]
  5943. rnumel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
  5944. hint, num_split = ir.Reduction.num_splits(
  5945. **kwargs,
  5946. reduction_type="online_softmax_reduce", # type: ignore[arg-type]
  5947. reduction_numel=rnumel,
  5948. )
  5949. if (
  5950. num_split == 1
  5951. and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
  5952. ):
  5953. max_tensor, sum_tensor = OnlineSoftmaxReduction.create(
  5954. input_node=x, num_output=2, reduction_hint=hint, **kwargs
  5955. )
  5956. return max_tensor, sum_tensor
  5957. else:
  5958. # Note: [Split online_softmax_reduce]
  5959. # We don't split reduction for online_softmax_reduce for now.
  5960. # On one hand, supporting split reduction makes things complex since
  5961. # the split out reuctions requires 2 inputs rather than one.
  5962. # On the other hand, during training the online_softmax_reduce should
  5963. # usually don't requires a split due to large batch size
  5964. # (more specifically batch size times sequence length).
  5965. # We should support split reduction if we find legit use cases to
  5966. # motivate the work.
  5967. #
  5968. # TODO: does inference need split online_softmax_reduce?
  5969. warnings.warn(
  5970. textwrap.dedent(
  5971. """
  5972. Online softmax is disabled on the fly since Inductor decides to
  5973. split the reduction. Cut an issue to PyTorch if this is an
  5974. important use case and you want to speed it up with online
  5975. softmax.
  5976. """
  5977. )
  5978. )
  5979. amax = reduce_amax(x, dim, keepdims=True)
  5980. exp = lowerings[aten.exp](sub(x, amax))
  5981. xsum = sum_(exp, dim, keepdims=True)
  5982. return amax, xsum
# populate lowerings defined in kernel/*
# The kernel package is imported (and its submodules via import_submodule)
# for the registration side effects of those modules.
from . import kernel

import_submodule(kernel)

# Quantized op lowerings, plus register_woq_mm_ops (presumably weight-only
# quantized matmuls — confirm).
from . import quantized_lowerings

quantized_lowerings.register_quantized_ops()
quantized_lowerings.register_woq_mm_ops()

# oneDNN fusion op lowerings.
from . import mkldnn_lowerings

mkldnn_lowerings.register_onednn_fusion_ops()

# Jagged op lowerings.
from . import jagged_lowerings

jagged_lowerings.register_jagged_ops()
  5993. @contextlib.contextmanager
  5994. def force_fallback(op: torch._ops.OpOverload):
  5995. """
  5996. A context manager to force fallback an op. Used in unit test
  5997. for FallbackKernel.
  5998. """
  5999. assert isinstance(op, torch._ops.OpOverload), (
  6000. "Only OpOverload to make the clean up easier"
  6001. )
  6002. old_handler = lowerings.get(op)
  6003. try:
  6004. register_lowering(op)(fallback_handler(op))
  6005. yield
  6006. finally:
  6007. if old_handler:
  6008. lowerings[op] = old_handler
  6009. else:
  6010. lowerings.pop(op)