| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
77727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763
27732783279328032813282328332843285328632873288328932903291329232933294329532963297329832993300330133023303330433053306330733083309331033113312331333143315331633173318331933203321332233233324332533263327332833293330333133323333333433353336333733383339334033413342334333443345334633473348334933503351335233533354335533563357335833593360336133623363336433653366336733683369337033713372337333743375337633773378337933803381338233833384338533863387338833893390339133923393339433953396339733983399340034013402340334043405340634073408340934103411341234133414341534163417341834193420342134223423342434253426342734283429343034313432343334343435343634373438343934403441344234433444344534463447344834493450345134523453345434553456345734583459346034613462346334643465346634673468346934703471347234733474347534763477347834793480348134823483348434853486348734883489349034913492349334943495349634973498349935003501350235033504350535063507350835093510351135123513351435153516351735183519352035213522352335243525352635273528352935303531353235333534353535363537353835393540354135423543354435453546354735483549355035513552355335543555355635573558355935603561356235633564356535663567356835693570357135723573357435753576357735783579358035813582358335843585358635873588358935903591359235933594359535963597359835993600360136023603360436053606360736083609361036113612361336143615361636173618361936203621362236233624362536263627362836293630363136323633363436353636363736383639364036413642364336443645364636473648364936503651365236533654365536563657365836593660366136623663366436653666366736683669367036713672367336743675367636773678367936803681368236833684368536863687368836893690369136923693369436953696369736983699370037013702370337043705370637073708370937103711371237133714371537163717371837193720372137223723372437253726372737283729373037313732373337343735373637373738373937403741374237433744374537463747374837493750375137523753375437553756375737583759376037613762376337643765376637673768376937703771377237733774377537763
77737783779378037813782378337843785378637873788378937903791379237933794379537963797379837993800380138023803380438053806380738083809381038113812381338143815381638173818381938203821382238233824382538263827382838293830383138323833383438353836383738383839384038413842384338443845384638473848384938503851385238533854385538563857385838593860386138623863386438653866386738683869387038713872387338743875387638773878387938803881388238833884388538863887388838893890389138923893389438953896389738983899390039013902390339043905390639073908390939103911391239133914391539163917391839193920392139223923392439253926392739283929393039313932393339343935393639373938393939403941394239433944394539463947394839493950395139523953395439553956395739583959396039613962396339643965396639673968396939703971397239733974397539763977397839793980398139823983398439853986398739883989399039913992399339943995399639973998399940004001400240034004400540064007400840094010401140124013401440154016401740184019402040214022402340244025402640274028402940304031403240334034403540364037403840394040404140424043404440454046404740484049405040514052405340544055405640574058405940604061406240634064406540664067406840694070407140724073407440754076407740784079408040814082408340844085408640874088408940904091409240934094409540964097409840994100410141024103410441054106410741084109411041114112411341144115411641174118411941204121412241234124412541264127412841294130413141324133413441354136413741384139414041414142414341444145414641474148414941504151415241534154415541564157415841594160416141624163416441654166416741684169417041714172417341744175417641774178417941804181418241834184418541864187418841894190419141924193419441954196419741984199420042014202420342044205420642074208420942104211421242134214421542164217421842194220422142224223422442254226422742284229423042314232423342344235423642374238423942404241424242434244424542464247424842494250425142524253425442554256425742584259426042614262426342644265426642674268426942704271427242734274427542764
27742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764
77747784779478047814782478347844785478647874788478947904791479247934794479547964797479847994800480148024803480448054806480748084809481048114812481348144815481648174818481948204821482248234824482548264827482848294830483148324833483448354836483748384839484048414842484348444845484648474848484948504851485248534854485548564857485848594860486148624863486448654866486748684869487048714872487348744875487648774878487948804881488248834884488548864887488848894890489148924893489448954896489748984899490049014902490349044905490649074908490949104911491249134914491549164917491849194920492149224923492449254926492749284929493049314932493349344935493649374938493949404941494249434944494549464947494849494950495149524953495449554956495749584959496049614962496349644965496649674968496949704971497249734974497549764977497849794980498149824983498449854986498749884989499049914992499349944995499649974998499950005001500250035004500550065007500850095010501150125013501450155016501750185019502050215022502350245025502650275028502950305031503250335034503550365037503850395040504150425043504450455046504750485049505050515052505350545055505650575058505950605061506250635064506550665067506850695070507150725073507450755076507750785079508050815082508350845085508650875088508950905091509250935094509550965097509850995100510151025103510451055106510751085109511051115112511351145115511651175118511951205121512251235124512551265127512851295130513151325133513451355136513751385139514051415142514351445145514651475148514951505151515251535154515551565157515851595160516151625163516451655166516751685169517051715172517351745175517651775178517951805181518251835184518551865187518851895190519151925193519451955196519751985199520052015202520352045205520652075208520952105211521252135214521552165217521852195220522152225223522452255226522752285229523052315232523352345235523652375238523952405241524252435244524552465247524852495250525152525253525452555256525752585259526052615262526352645265526652675268526952705271527252735274527552765
27752785279528052815282528352845285528652875288528952905291529252935294529552965297529852995300530153025303530453055306530753085309531053115312531353145315531653175318531953205321532253235324532553265327532853295330533153325333533453355336533753385339534053415342534353445345534653475348534953505351535253535354535553565357535853595360536153625363536453655366536753685369537053715372537353745375537653775378537953805381538253835384538553865387538853895390539153925393539453955396539753985399540054015402540354045405540654075408540954105411541254135414541554165417541854195420542154225423542454255426542754285429543054315432543354345435543654375438543954405441544254435444544554465447544854495450545154525453545454555456545754585459546054615462546354645465546654675468546954705471547254735474547554765477547854795480548154825483548454855486548754885489549054915492549354945495549654975498549955005501550255035504550555065507550855095510551155125513551455155516551755185519552055215522552355245525552655275528552955305531553255335534553555365537553855395540554155425543554455455546554755485549555055515552555355545555555655575558555955605561556255635564556555665567556855695570557155725573557455755576557755785579558055815582558355845585558655875588558955905591559255935594559555965597559855995600560156025603560456055606560756085609561056115612561356145615561656175618561956205621562256235624562556265627562856295630563156325633563456355636563756385639564056415642564356445645564656475648564956505651565256535654565556565657565856595660566156625663566456655666566756685669567056715672567356745675567656775678567956805681568256835684568556865687568856895690569156925693569456955696569756985699570057015702570357045705570657075708570957105711571257135714571557165717571857195720572157225723572457255726572757285729573057315732573357345735573657375738573957405741574257435744574557465747574857495750575157525753575457555756575757585759576057615762576357645765576657675768576957705771577257735774577557765
77757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268626962706271627262736274627562766
27762786279628062816282628362846285628662876288628962906291629262936294629562966297629862996300630163026303630463056306630763086309631063116312631363146315631663176318631963206321632263236324632563266327632863296330633163326333633463356336633763386339634063416342634363446345634663476348634963506351635263536354635563566357635863596360636163626363636463656366636763686369637063716372637363746375637663776378637963806381638263836384638563866387638863896390639163926393639463956396639763986399640064016402640364046405640664076408640964106411641264136414641564166417641864196420642164226423642464256426642764286429643064316432643364346435643664376438643964406441644264436444644564466447644864496450645164526453645464556456645764586459646064616462646364646465646664676468646964706471647264736474647564766477647864796480648164826483648464856486648764886489649064916492649364946495649664976498649965006501650265036504650565066507650865096510651165126513651465156516651765186519652065216522652365246525652665276528652965306531653265336534653565366537653865396540654165426543654465456546654765486549655065516552655365546555655665576558655965606561656265636564656565666567656865696570657165726573657465756576657765786579658065816582658365846585658665876588658965906591659265936594659565966597659865996600660166026603660466056606660766086609661066116612661366146615661666176618661966206621662266236624662566266627662866296630663166326633663466356636663766386639664066416642664366446645664666476648664966506651665266536654665566566657665866596660666166626663666466656666666766686669667066716672667366746675667666776678667966806681668266836684668566866687668866896690669166926693669466956696669766986699670067016702670367046705670667076708670967106711671267136714671567166717671867196720672167226723672467256726672767286729673067316732673367346735673667376738673967406741674267436744674567466747674867496750675167526753675467556756675767586759676067616762676367646765676667676768676967706771677267736774677567766
77767786779678067816782678367846785678667876788678967906791679267936794679567966797679867996800680168026803680468056806680768086809681068116812681368146815681668176818681968206821682268236824682568266827682868296830683168326833683468356836683768386839684068416842684368446845684668476848684968506851685268536854685568566857685868596860686168626863686468656866686768686869687068716872687368746875687668776878687968806881688268836884688568866887688868896890689168926893689468956896689768986899690069016902690369046905690669076908690969106911691269136914691569166917691869196920692169226923692469256926692769286929693069316932693369346935693669376938693969406941694269436944694569466947694869496950695169526953695469556956695769586959696069616962696369646965696669676968696969706971697269736974697569766977697869796980698169826983698469856986698769886989699069916992699369946995699669976998699970007001700270037004700570067007700870097010701170127013701470157016701770187019702070217022702370247025702670277028702970307031703270337034703570367037703870397040704170427043704470457046704770487049705070517052705370547055705670577058705970607061706270637064706570667067706870697070707170727073707470757076707770787079708070817082708370847085708670877088708970907091709270937094709570967097709870997100710171027103710471057106710771087109711071117112711371147115711671177118711971207121712271237124712571267127712871297130713171327133713471357136713771387139714071417142714371447145714671477148714971507151715271537154715571567157715871597160716171627163716471657166716771687169717071717172717371747175717671777178717971807181718271837184718571867187718871897190719171927193719471957196719771987199720072017202720372047205720672077208720972107211721272137214721572167217721872197220722172227223722472257226722772287229723072317232723372347235723672377238723972407241724272437244724572467247724872497250725172527253725472557256725772587259726072617262726372647265726672677268726972707271727272737274727572767
277727872797280728172827283728472857286728772887289729072917292729372947295729672977298 |
- # mypy: allow-untyped-defs
- from __future__ import annotations
- import contextlib
- import dataclasses
- import functools
- import itertools
- import logging
- import math
- import operator
- import os
- import textwrap
- import warnings
- from collections import defaultdict
- from collections.abc import Iterable, Sequence
- from typing import Any, Callable, cast, Optional, TYPE_CHECKING, TypeVar, Union
- from typing_extensions import ParamSpec
- from unittest.mock import patch
- import sympy
- import torch
- import torch.ao.quantization.fx._decomposed
- import torch.fx
- import torch.utils._pytree as pytree
- from torch._dynamo.utils import counters
- from torch._higher_order_ops.associative_scan import associative_scan_op
- from torch._higher_order_ops.triton_kernel_wrap import triton_kernel_wrapper_mutation
- from torch._library.utils import get_layout_constraint_tag
- from torch._prims_common import (
- canonicalize_dim,
- canonicalize_dims,
- check,
- dtype_to_type,
- elementwise_dtypes,
- ELEMENTWISE_TYPE_PROMOTION_KIND,
- get_computation_dtype,
- is_boolean_dtype,
- is_float_dtype,
- is_integer_dtype,
- Number,
- )
- from torch.fx.experimental.sym_node import magic_methods, method_to_operator
- from torch.fx.experimental.symbolic_shapes import (
- free_unbacked_symbols,
- has_free_unbacked_symbols,
- resolve_unbacked_bindings,
- )
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._sympy.functions import CeilDiv, FloorDiv, Identity, ModularIndexing
- from .._dynamo.utils import import_submodule
- from . import config, inductor_prims, ir, test_operators # NOQA: F401
- from .decomposition import decompositions, get_decompositions
- from .ir import (
- BaseView,
- DtypeView,
- ExpandView,
- IndexingConstant,
- IRNode,
- is_triton,
- MutableBox,
- OnlineSoftmaxReduction,
- ops_wrapper,
- PermuteView,
- Pointwise,
- Reduction,
- ShapeAsConstantBuffer,
- SqueezeView,
- TensorBox,
- validate_ir,
- View,
- )
- from .utils import (
- ceildiv,
- decode_device,
- is_dynamic,
- is_gpu,
- is_pointwise_use,
- is_view,
- needs_fallback_due_to_atomic_add_limitations,
- pad_listlike,
- register_op_dtype_propagation_rules,
- register_op_requires_libdevice_fp64,
- sympy_product,
- use_scatter_fallback,
- )
- from .virtualized import ops, V
- if TYPE_CHECKING:
- from .ops_handler import ReductionType
# Generic type parameters available to typed helpers in this module.
_T = TypeVar("_T")
_P = ParamSpec("_P")

log = logging.getLogger(__name__)

# Short aliases for the torch op namespaces referenced in this module.
aten = torch.ops.aten
prims = torch.ops.prims
tr_c10d = torch.ops.tr_c10d
quantized_decomposed = torch.ops.quantized_decomposed

# TODO(jansel): we should implement decomps or lowerings for these
# https://github.com/pytorch/torchdynamo/issues/327
FALLBACK_ALLOW_LIST = OrderedSet(["torchvision::roi_align", "aten::index_add"])

# Keyed by an op callable or a string; values are callables.
# NOTE(review): presumably the lowering implementation for each key --
# confirm against the registration sites.
lowerings: dict[Union[Callable[..., Any], str], Callable[..., Any]] = {}

# Do not read this dict directly: go through maybe_layout_constraints().
# Tag-based layout constraints are registered lazily.
_maybe_layout_constraints: dict[
    torch._ops.OpOverload, Optional[Callable[..., Any]]
] = {}

# Op sets populated elsewhere in the module (both start out empty).
fallbacks = OrderedSet[torch._ops.OpOverload]()
needs_realized_inputs = OrderedSet[torch._ops.OpOverload]()

# TODO(rec): torch._higher_order_ops._foreach_map is not an OpOverload
# so why is it in foreach_ops?
foreach_ops = OrderedSet[torch._ops.OpOverload](
    [torch._higher_order_ops._foreach_map]  # type: ignore[list-item]
)
inplace_foreach_ops = OrderedSet[torch._ops.OpOverload]()
inplaceable_foreach_ops: dict[torch._ops.OpOverload, torch._ops.OpOverload] = {}
def cur_node_has_non_foreach_users():
    """Report whether the current node has a non-foreach consumer two hops away.

    Visits every user of every user of ``V.graph.current_node`` and returns
    True as soon as one is found that is not a ``call_function`` node whose
    target is in ``foreach_ops``; returns False when there is no such user.
    """

    def _is_foreach_call(consumer):
        return consumer.op == "call_function" and consumer.target in foreach_ops

    return any(
        not _is_foreach_call(consumer)
        for direct_user in V.graph.current_node.users
        for consumer in direct_user.users
    )
# group by device, whether any of the inputs are dynamic
# note arg_pairs may or may not be a pair
# foreach_map for example just passes output buffers here
def group_foreach_args(arg_pairs: Iterable[Union[tuple[Any, Any], Any]]):
    """Bucket foreach arguments by ``(device, use_foreach)``.

    Args:
        arg_pairs: iterable whose entries are either iterables of arguments or
            single (non-iterable) arguments.

    Returns:
        A ``defaultdict(list)`` mapping ``(device, use_foreach)`` to a list of
        ``(original_index, args)`` entries, in input order.

    Raises:
        AssertionError: if an entry contains no ``TensorBox`` with a device.
    """
    grouped = defaultdict(list)
    # Once a bare (non-iterable) entry has been seen, every later entry is
    # unwrapped again before being stored; the flag is never reset.
    wrapped_single = False
    for position, entry in enumerate(arg_pairs):
        if not isinstance(entry, Iterable):
            wrapped_single = True
            entry = (entry,)
        foreach_ok = (
            not is_dynamic(*entry) or config.combo_kernel_foreach_dynamic_shapes
        )
        # Device of the first TensorBox in the entry decides the bucket.
        device = next(
            (t.data.get_device() for t in entry if isinstance(t, TensorBox)),
            None,
        )
        assert device is not None, "foreach op should have at least one tensor arg"
        if wrapped_single:
            (entry,) = entry
        grouped[(device, foreach_ok)].append((position, entry))
    return grouped
def maybe_layout_constraints(fn: Callable[..., Any]) -> Optional[Callable[..., Any]]:
    """Get layout constraints. Returns None if there are no layout constraints.

    A layout tag found on the op takes precedence over an entry registered in
    ``_maybe_layout_constraints``.
    """
    # Only OpOverloads have layout constraints.
    if not isinstance(fn, torch._ops.OpOverload):
        return None

    layout_tag = get_layout_constraint_tag(fn, with_default=False)
    if layout_tag:
        return tag_to_layout_constraint(layout_tag)

    return _maybe_layout_constraints.get(fn)
def tag_to_layout_constraint(tag):
    """Map a layout-constraint tag to the function that enforces it.

    Returns None for ``flexible_layout`` and raises ``AssertionError`` for any
    tag that is not one of the four handled here.
    """
    Tag = torch._C.Tag
    if tag == Tag.needs_exact_strides:
        return constrain_to_fake_tensors
    elif tag == Tag.needs_contiguous_strides:  # type: ignore[attr-defined]
        return require_contiguous_strides
    elif tag == Tag.needs_fixed_stride_order:
        return constrain_to_fx_strides
    elif tag == Tag.flexible_layout:
        return None
    raise AssertionError(f"Unknown layout constraint tag: {tag}")
def assert_nyi(cond, msg):
    """Raise ``NotImplementedError`` mentioning ``msg`` unless ``cond`` is truthy."""
    if cond:
        return
    raise NotImplementedError(f"inductor does not support {msg}")
def add_needs_realized_inputs(fn):
    """Record ``fn`` in ``needs_realized_inputs``.

    A list/set/tuple/OrderedSet is handled element by element (a list of the
    per-element results is returned). For an ``OpOverloadPacket`` every one of
    its overloads is recorded as well as the packet itself.
    """
    if isinstance(fn, (list, set, tuple, OrderedSet)):  # noqa: set_linter
        return [add_needs_realized_inputs(item) for item in fn]

    needs_realized_inputs.add(fn)
    if isinstance(fn, torch._ops.OpOverloadPacket):
        for overload in fn.overloads():
            needs_realized_inputs.add(getattr(fn, overload))
def add_layout_constraint(fn, constraint):
    """Register ``constraint`` for ``fn`` in ``_maybe_layout_constraints``.

    When ``fn`` is an ``OpOverloadPacket`` the constraint is stored once per
    overload (the packet itself is not used as a key); otherwise ``fn`` is the
    key.
    """
    if isinstance(fn, torch._ops.OpOverloadPacket):
        targets = [getattr(fn, overload) for overload in fn.overloads()]
    else:
        targets = [fn]
    for target in targets:
        _maybe_layout_constraints[target] = constraint
# Ops registered up front as needing realized inputs.
add_needs_realized_inputs(
    [
        aten.as_strided,
        aten.as_strided_copy,
        aten.avg_pool2d,
        aten.avg_pool2d_backward,
        aten.bmm,
        aten.convolution,
        aten.convolution_backward,
        aten.max_pool2d_with_indices,
        aten.max_pool3d_with_indices,
        aten.max_pool2d_with_indices_backward,
        aten.mm,
        aten.upsample_nearest2d,
        aten._upsample_nearest_exact2d,
        aten._int_mm,
    ]
)
# TODO(jansel): ezyang says we won't need this in the future, try removing it
# based on https://github.com/pytorch/pytorch/blob/9e3eb329df8f701/c10/core/ScalarType.h#L28
#
# Maps the integer id of a c10::ScalarType to the corresponding torch.dtype.
# Ids 8/9/10 are ComplexHalf/ComplexFloat/ComplexDouble, so id 10 must be
# complex128 (it was previously mapped to complex32, duplicating id 8).
DTYPE_ID_LOOKUP = {
    0: torch.uint8,
    1: torch.int8,
    2: torch.int16,
    3: torch.int32,
    4: torch.int64,
    5: torch.float16,
    6: torch.float32,
    7: torch.float64,
    8: torch.complex32,
    9: torch.complex64,
    10: torch.complex128,
    11: torch.bool,
    15: torch.bfloat16,
    # TODO(jansel): add quantized types?
    #  _(c10::qint8, QInt8) /* 12 */
    # _(c10::quint8, QUInt8) /* 13 */
    # _(c10::qint32, QInt32) /* 14 */
    # _(c10::quint4x2, QUInt4x2) /* 16 */
    # _(c10::quint2x4, QUInt2x4) /* 17 */
}
def decode_dtype(dtype: int):
    """Translate an integer dtype id via ``DTYPE_ID_LOOKUP``.

    Anything that is not an ``int`` is returned unchanged. An integer id that
    is missing from the table trips an assertion.
    """
    if isinstance(dtype, int):
        assert dtype in DTYPE_ID_LOOKUP, f"id {dtype} missing from DTYPE_ID_LOOKUP"
        return DTYPE_ID_LOOKUP[dtype]
    return dtype
def is_integer_type(x):
    """Tell whether ``x`` is integer-valued.

    A ``TensorBox`` qualifies when its dtype is an integer or boolean dtype, a
    sympy expression when ``is_integer`` is exactly True, and any other value
    when it is a Python ``int``.
    """
    if isinstance(x, TensorBox):
        dtype = x.get_dtype()
        return is_integer_dtype(dtype) or is_boolean_dtype(dtype)
    if isinstance(x, sympy.Expr):
        return x.is_integer is True  # type: ignore[attr-defined]
    return isinstance(x, int)
def is_boolean_type(x):
    """Tell whether ``x`` is a boolean ``TensorBox`` or a Python ``bool``."""
    if not isinstance(x, TensorBox):
        return isinstance(x, bool)
    return is_boolean_dtype(x.get_dtype())
def get_promoted_dtype(*args, type_promotion_kind: ELEMENTWISE_TYPE_PROMOTION_KIND):
    """Compute the result dtype of an elementwise op over ``args``.

    Numbers and sympy objects are passed to ``elementwise_dtypes`` as-is; every
    other argument is replaced by a tiny stand-in tensor that has the same
    number of dimensions and the same dtype.
    """

    def construct_input(inp):
        if isinstance(inp, (Number, sympy.Basic)):
            return inp
        # construct a tmp tensor to feed into torch.result_type
        ndim = len(inp.get_size())
        return torch.zeros([1] * ndim, dtype=inp.get_dtype())

    stand_ins = [construct_input(arg) for arg in args]
    _, result_dtype = elementwise_dtypes(
        *stand_ins, type_promotion_kind=type_promotion_kind
    )
    return result_dtype
def get_overloads(aten_fn):
    """Return ``aten_fn`` as a fresh list, extended with packet overloads.

    A single op is wrapped in a list; a list/tuple is copied. For each
    ``OpOverloadPacket`` in the input, every overload that does not already
    have an entry in ``lowerings`` is appended after the original entries.
    """
    fns = list(aten_fn) if isinstance(aten_fn, (list, tuple)) else [aten_fn]
    # Iterate over a snapshot because the list grows while we scan it.
    for fn in tuple(fns):
        if not isinstance(fn, torch._ops.OpOverloadPacket):
            continue
        for overload in fn.overloads():
            other_fn = getattr(fn, overload)
            if other_fn not in lowerings:
                fns.append(other_fn)
    return fns
def in_namespace(op, namespace):
    """Check whether ``namespace`` occurs in the qualified name of ``op``.

    Only ``OpOverloadPacket`` and ``OpOverload`` objects are inspected; any
    other value yields False. The check is a substring test.
    """
    if isinstance(op, torch._ops.OpOverloadPacket):
        qualified_name = op._qualified_op_name
    elif isinstance(op, torch._ops.OpOverload):
        qualified_name = op.name()
    else:
        return False
    return namespace in qualified_name
def maybe_copy_cpu_scalar(x: TensorBox, device: torch.device) -> TensorBox:
    """
    Copy a cpu scalar if its device does not match the given `device`.

    Only inputs whose data is an ``ir.ReinterpretView`` and whose size has no
    free unbacked symbols are considered; everything else is returned as-is.
    A copy is made when the input lives on a cpu device different from
    ``device`` and is either 0-dim or 1-dim with a single element.
    """
    if not isinstance(x.data, ir.ReinterpretView):
        return x
    if has_free_unbacked_symbols(x.get_size()):
        return x

    size = [V.graph.sizevars.size_hint_or_throw(s) for s in x.get_size()]
    cur_device = x.get_device()
    on_other_cpu = (
        cur_device is not None and cur_device.type == "cpu" and cur_device != device
    )
    single_element = len(size) == 0 or (len(size) == 1 and size[0] == 1)
    if on_other_cpu and single_element:
        # NOTE(review): the copy targets `cur_device`, not `device` — confirm intent.
        return TensorBox(ir.StorageBox(ir.DeviceCopy.create(x, cur_device, False)))
    return x
def transform_args(
    args: list[Any],
    kwargs: dict[str, Any],
    broadcast: bool,
    type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
    convert_input_to_bool: bool,
) -> tuple[list[Any], dict[str, Any]]:
    """
    Transforms arguments for broadcasting and type promotion

    Returns the (possibly rebuilt) ``args`` and ``kwargs``. When neither
    contains a ``TensorBox`` they are returned untouched.
    """
    tensor_positions = [i for i, a in enumerate(args) if isinstance(a, TensorBox)]
    tensor_keys = [k for k, v in kwargs.items() if isinstance(v, TensorBox)]
    # check that there's something to transform
    if not (tensor_positions or tensor_keys):
        return args, kwargs

    if type_promotion_kind or convert_input_to_bool:
        if convert_input_to_bool:
            dtype = torch.bool
        else:
            # FIXME this is a crude approximation for promoting args
            promoting_args = [
                a
                for a in args
                if isinstance(a, (Number, sympy.Basic)) or hasattr(a, "dtype")
            ]
            # only consider tensor kwargs for promotion, for now
            promoting_args.extend(a for a in kwargs.values() if hasattr(a, "dtype"))
            dtype = get_promoted_dtype(
                *promoting_args,
                type_promotion_kind=type_promotion_kind,  # type: ignore[arg-type]
            )

        # The first tensor argument (positional first, then keyword) supplies
        # the reference device.
        if tensor_positions:
            first_tensor = args[tensor_positions[0]]
        else:
            first_tensor = kwargs[tensor_keys[0]]
        device = first_tensor.get_device()

        for i in tensor_positions:
            args[i] = maybe_copy_cpu_scalar(args[i], device)
        for k in tensor_keys:
            kwargs[k] = maybe_copy_cpu_scalar(kwargs[k], device)

        # sometimes args are an immutable list so we can't mutate them
        def promote(arg):
            if isinstance(arg, TensorBox):
                return to_dtype(arg, dtype)
            if isinstance(arg, ir.Constant):
                return ir.Constant(value=arg.value, dtype=dtype, device=device)
            return arg

        args = [promote(a) for a in args]
        kwargs = {k: promote(v) for k, v in kwargs.items()}

    if broadcast:
        tensors = list(
            itertools.chain(
                (args[i] for i in tensor_positions),
                (kwargs[k] for k in tensor_keys),
            )
        )
        broadcasted = broadcast_tensors(*tensors)
        size = list(broadcasted[0].get_size())
        n_positional = len(tensor_positions)
        for i, x in zip(tensor_positions, broadcasted[:n_positional]):
            args[i] = x
        for k, x in zip(tensor_keys, broadcasted[n_positional:]):
            kwargs[k] = x
        # Constants are expanded to the broadcast size as well.
        for i, a in enumerate(args):
            if isinstance(a, ir.Constant):
                args[i] = ExpandView.create(a, size)
        for k in kwargs:
            if isinstance(kwargs[k], ir.Constant):
                kwargs[k] = ExpandView.create(kwargs[k], size)

    return args, kwargs
def _register_foreach_lowering(aten_fn, decomp_fn):
    """
    Add a foreach lowering to lowerings dict.

    Arguments:
        aten_fn: torch.ops.aten.* fn we are lowering
        decomp_fn: alternate implementation on our IR

    Returns:
        The wrapper stored in ``lowerings``; it asserts at most two positional
        arguments, calls ``decomp_fn`` and validates the produced IR.
    """

    @functools.wraps(decomp_fn)
    def wrapped(*args, **kwargs):
        assert len(args) <= 2
        result = decomp_fn(*args, **kwargs)
        validate_ir(result)
        return result

    overloads = get_overloads(aten_fn)
    foreach_ops.update(overloads)
    for op in overloads:
        lowerings[op] = wrapped
    return wrapped
def _register_lowering(
    aten_fn,
    decomp_fn,
    broadcast,
    type_promotion_kind: Optional[ELEMENTWISE_TYPE_PROMOTION_KIND],
    convert_input_to_bool,
    lowering_dict,
):
    """
    Add a lowering to lowerings dict

    Arguments:
        aten_fn: torch.ops.aten.* fn we are lowering
        decomp_fn: alternate implementation on our IR
        broadcast: True to apply broadcasting to tensor inputs
        type_promotion_kind: kind of type promotion applied to tensor inputs, `None` means no type promotion
        convert_input_to_bool: some logical ops require inputs are converted to bool
        lowering_dict: dict that receives one entry per overload of ``aten_fn``

    Returns:
        The wrapper that was registered.
    """

    @functools.wraps(decomp_fn)
    def wrapped(*args, **kwargs):
        arg_list: list[Any] = list(args)
        kwarg_map: dict[str, Any] = dict(kwargs)
        # TODO maybe we need to use pytrees here
        # A single list/tuple positional argument is flattened for the
        # transformation and re-wrapped afterwards.
        unpacked = len(arg_list) == 1 and isinstance(arg_list[0], (list, tuple))
        if unpacked:
            arg_list = list(arg_list[0])

        # `aten_fn` is the overload list bound below, after this closure is defined.
        has_regular_op = any(
            fn not in fallbacks and not in_namespace(fn, "_c10d_functional")
            for fn in aten_fn
        )
        if has_regular_op:
            # explicitly assert for "out=" ops for better error messages
            assert "out" not in kwarg_map, "out= ops aren't yet supported"

        arg_list, kwarg_map = transform_args(
            arg_list, kwarg_map, broadcast, type_promotion_kind, convert_input_to_bool
        )
        if unpacked:
            arg_list = [arg_list]

        result = decomp_fn(*arg_list, **kwarg_map)
        validate_ir(result)
        return result

    aten_fn = get_overloads(aten_fn)
    lowering_dict.update(dict.fromkeys(aten_fn, wrapped))
    return wrapped
def register_lowering(
    aten_fn,
    broadcast=False,
    type_promotion_kind: Optional[
        ELEMENTWISE_TYPE_PROMOTION_KIND
    ] = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    convert_input_to_bool=False,
    lowering_dict=lowerings,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """
    Shim to support decorator syntax.

    Returns a ``functools.partial`` of ``_register_lowering`` with everything
    but the implementation function already bound.
    """
    options = {
        "broadcast": broadcast,
        "type_promotion_kind": type_promotion_kind,
        "convert_input_to_bool": convert_input_to_bool,
        "lowering_dict": lowering_dict,
    }
    return functools.partial(_register_lowering, aten_fn, **options)
def broadcast_symbolic_shapes(a, b):
    """
    Broadcasting logic based on symbolic shapes.

    We give the shapes 0 and 1 concrete values, while all other shapes
    are symbolic sympy formulas.

    The shapes are aligned from the trailing dimension; a missing leading
    dimension is treated as 1. Returns the broadcast shape as a tuple.
    """
    reversed_result = []
    for x, y in itertools.zip_longest(reversed(a), reversed(b), fillvalue=sympy.S.One):
        if V.graph.sizevars.is_size_one_or_false(y):
            picked = x
        elif V.graph.sizevars.is_size_one_or_false(x):
            picked = y
        else:
            V.graph.sizevars.check_equals(x, y)
            y_symbols = len(sympy.expand(y).free_symbols)
            x_symbols = len(sympy.expand(x).free_symbols)
            # prefer shorter formula
            picked = y if y_symbols < x_symbols else x
        reversed_result.append(picked)
    return tuple(reversed(reversed_result))
def promote_constants(inputs, override_return_dtype=None, type_promotion_kind=None):
    """Turn Python/sympy scalars among ``inputs`` into IR constants.

    * No scalars present: ``inputs`` is returned unchanged.
    * Only scalars present: each becomes an ``ir.Constant`` or
      ``ir.IndexingConstant`` with ``override_return_dtype`` or the promoted
      dtype.
    * Mixed: each scalar becomes a constant expanded to the size of the first
      ``TensorBox``/``ExpandView``/``ir.Constant`` input, using that input's
      dtype and device; other inputs pass through.

    At most one of ``override_return_dtype`` and ``type_promotion_kind`` may be
    given.
    """
    assert override_return_dtype is None or type_promotion_kind is None, (
        "only one of override_return_dtype or type_promotion_kind may be given"
    )
    if override_return_dtype is None and type_promotion_kind is None:
        type_promotion_kind = ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT

    scalar_types = (sympy.Basic, int, float)
    if not any(isinstance(x, scalar_types) for x in inputs):
        return inputs

    if all(isinstance(x, scalar_types) for x in inputs):
        dtype = override_return_dtype or get_promoted_dtype(
            *inputs, type_promotion_kind=type_promotion_kind
        )
        return [
            ir.IndexingConstant(index=x, dtype=dtype, device=decode_device(None))
            if isinstance(x, sympy.Basic)
            else ir.Constant(value=x, dtype=dtype, device=decode_device(None))
            for x in inputs
        ]

    # Mixed case: borrow dtype/device/size from the first non-scalar input.
    ex = next(x for x in inputs if isinstance(x, (TensorBox, ExpandView, ir.Constant)))

    def expanded_like_ex(x):
        if isinstance(x, (int, float)):
            const = ir.Constant(
                value=x, dtype=ex.get_dtype(), device=ex.get_device_or_error()
            )
        elif isinstance(x, sympy.Basic):
            const = IndexingConstant(
                index=x, dtype=ex.get_dtype(), device=ex.get_device_or_error()
            )
        else:
            return x
        return ExpandView.create(const, list(ex.get_size()))

    return [expanded_like_ex(x) for x in inputs]
def make_pointwise(
    fn,
    override_return_dtype=None,
    override_device=None,
    override_fn_when_input_bool=None,
    allow_alpha=False,
    triton_fallback=None,
):
    """Build a lowering that applies ``fn`` elementwise and returns a ``Pointwise``.

    Args:
        fn: callable applied to the loaded input values.
        override_return_dtype: result dtype; defaults to the first input's dtype.
        override_device: result device; otherwise the first GPU input's device,
            falling back to the first input's device.
        override_fn_when_input_bool: used instead of ``fn`` when the result
            dtype is ``torch.bool``.
        allow_alpha: accept an ``alpha`` keyword that scales the last input.
        triton_fallback: called instead when any input is a triton IR node.
    """

    def inner(*inputs: TensorBox, alpha=None):
        if triton_fallback is not None and any(
            isinstance(inp, IRNode) and is_triton(inp) for inp in inputs
        ):
            assert not allow_alpha  # not implemented
            return triton_fallback(*inputs)

        inputs = promote_constants(inputs, override_return_dtype)
        if not allow_alpha:
            assert alpha is None
        elif alpha is not None and alpha != 1:
            inputs = list(inputs)
            inputs[-1] = mul(inputs[-1], alpha)

        loaders = [x.make_loader() for x in inputs]
        ranges = inputs[0].get_size()
        dtype = override_return_dtype or inputs[0].get_dtype()

        for other in inputs[1:]:
            assert isinstance(other, ir.BaseConstant) or len(ranges) == len(
                other.get_size()
            ), f"ndim mismatch {fn} {ranges} {other.get_size()}"

        # in tracing, we will annotate pointwise nodes that correspond to the output of
        # a pointwise node that would have been run in eager. intermediary pointwise nodes
        # during decompositions are not annotated.
        low_pr_fp = (torch.bfloat16, torch.float16)
        emulate_precision_casts = (
            V.graph is not None
            and getattr(V.graph, "current_node", None) is not None
            and V.graph.current_node.meta is not None
            and V.graph.current_node.meta.get("low_precision_pointwise_barrier", False)
            and dtype in low_pr_fp
        )

        def inner_fn(index):
            assert len(index) == len(ranges), f"wrong ndim {index} {ranges}"
            if dtype == torch.bool and override_fn_when_input_bool is not None:
                return override_fn_when_input_bool(*[load(index) for load in loaders])

            loaded_values = []
            for load, source in zip(loaders, inputs):
                value = load(index)
                source_dtype = source.get_dtype()
                if emulate_precision_casts and source_dtype in low_pr_fp:
                    lowered = ops.to_dtype(value, source_dtype, use_compute_types=False)
                    value = ops.to_dtype(lowered, source_dtype)
                loaded_values.append(value)

            result = fn(*loaded_values)
            if emulate_precision_casts:
                # fp16/bf16 kernels are computed in fp32. Casting down to fp16/bf16 here,
                # then upcasting again, to emulate casts that eager would do.
                lowered = ops.to_dtype(result, dtype, use_compute_types=False)
                return ops.to_dtype(lowered, dtype)
            return result

        if override_device:
            device = override_device
        else:
            # First GPU input wins; otherwise use the first input's device.
            device = None
            for candidate in inputs:
                if is_gpu(candidate.get_device().type):
                    device = candidate.get_device()
                    break
            if not device:
                device = inputs[0].get_device()

        return Pointwise.create(
            device=device,  # type: ignore[arg-type]
            dtype=dtype,
            inner_fn=inner_fn,
            ranges=ranges,
        )

    return inner
def make_foreach_pointwise(pw_fn, allow_alpha=False):
    """Lift the pointwise lowering ``pw_fn`` to operate on lists of tensors.

    The returned function applies ``pw_fn`` to each tuple of corresponding
    elements of its list inputs (non-list inputs are repeated for every
    element) and returns the list of outputs in input order. When
    ``allow_alpha`` is set, ``alpha`` is forwarded to ``pw_fn``.
    """

    def inner(*inputs: list[list[TensorBox]], alpha=1):
        realize_outputs = (
            len(V.graph.current_node.users) == 0
            or V.graph.current_node.target in inplace_foreach_ops
            or cur_node_has_non_foreach_users()
        )

        a_list_input = next(
            (inp for inp in inputs if isinstance(inp, (list, tuple))), None
        )
        assert a_list_input is not None, (
            "at least one input must be a list to a foreach op"
        )

        # broadcast scalar inputs to match length of list inputs
        broadcast_inputs = [
            inp if isinstance(inp, (list, tuple)) else [inp] * len(a_list_input)
            for inp in inputs
        ]

        groups = group_foreach_args(zip(*broadcast_inputs))

        outputs = [None] * len(a_list_input)
        for (device, use_foreach), group in groups.items():
            operation_list: list[str] = []
            for output_ind, args in group:
                output = pw_fn(*args, alpha=alpha) if allow_alpha else pw_fn(*args)
                outputs[output_ind] = output

                if (
                    V.graph.has_feature(device, BackendFeature.FOREACH)
                    and use_foreach
                    and realize_outputs
                ):
                    output.realize()
                    operation_list.append(output.get_operation_name())

            if operation_list:
                V.graph.register_operation_list(operation_list)

        assert all(x is not None for x in outputs)
        return outputs

    return inner
def to_dtype(
    x: Union[TensorBox, ShapeAsConstantBuffer], dtype: torch.dtype, copy: bool = False
):
    """Convert ``x`` to ``dtype`` with a pointwise cast.

    When ``x`` already has ``dtype`` it is returned unchanged, or cloned if
    ``copy`` is true.
    """
    src_dtype = x.get_dtype()
    if src_dtype != dtype:

        def _to_dtype(x):
            return ops.to_dtype(x, dtype, src_dtype=src_dtype)

        return make_pointwise(_to_dtype, override_return_dtype=dtype)(x)

    if copy:
        return clone(x)
    return x
@register_lowering(torch._higher_order_ops._foreach_map, type_promotion_kind=None)
def _foreach_map(subgraph, *args, **kwargs):
    """
    This lowers an invocation of foreach_map

    The way this works is that an arbitrary N-arg func is provided by the user, looped over by the
    polyfill with the same semantics as a foreach op (a loop applying an n-ary function to n args)
    and then traced into a subgraph by dynamo.
    This code allows us to inline the subgraph into the main graph lowering using the
    PointwiseSubgraphLowering.
    The graph outputs represent the vertically fused sequence of ops, and then register_operation_list
    below registers the buffers as horizontally fuseable in the scheduler.
    """
    from .subgraph_lowering import PointwiseSubgraphLowering

    pw_subgraph = PointwiseSubgraphLowering(
        subgraph.graph_module, root_graph_lowering=V.graph
    )
    with V.set_graph_handler(pw_subgraph):  # type: ignore[arg-type]
        pw_subgraph.run(*args)

    sub_outputs = pw_subgraph.graph_outputs
    # group outputs by device and register as foreach
    assert sub_outputs  # mypy lol
    groups = group_foreach_args(sub_outputs)

    outputs = [None] * len(sub_outputs)
    for (device, use_foreach), group in groups.items():
        operation_list: list[str] = []
        for output_ind, output in group:
            outputs[output_ind] = output
            if V.graph.has_feature(device, BackendFeature.FOREACH) and use_foreach:
                output.realize()
                operation_list.append(output.get_operation_name())

        if operation_list:
            V.graph.register_operation_list(operation_list)

    assert all(x is not None for x in outputs)
    return outputs
@register_lowering(prims.convert_element_type, type_promotion_kind=None)
def _convert_element_type(x: TensorBox, dtype: torch.dtype):
    """Lower ``prims.convert_element_type``.

    Non-complex conversions go through ``to_dtype`` with a copy. When either
    side is complex, a destination buffer is filled via an in-place copy
    fallback if ``x.get_size()`` is non-empty, and the aten fallback is used
    otherwise.
    """
    involves_complex = dtype.is_complex or x.get_dtype().is_complex
    if not involves_complex:
        return to_dtype(x, dtype, copy=True)

    if not x.get_size():
        return fallback_handler(
            prims.convert_element_type.default, add_to_fallback_set=False
        )(x, dtype)

    # Decompose since an aten fallback is more friendly for c++ codegen.
    # This decomposition doesn't work for empty tensor, which needs more investigation.
    dst = empty_like(x, dtype=dtype)
    ir.InplaceCopyFallback.create(dst, x)
    return dst
def to_dtype_bitcast(x: TensorBox, dtype: torch.dtype, *, copy=False):
    """Reinterpret the bits of ``x`` as ``dtype``.

    Same dtype: returns ``x`` (cloned when ``copy``). Equal bit widths: wraps
    ``x`` in a ``DtypeView``. Different bit widths: defers to the
    ``aten.view.dtype`` fallback.
    """
    x_dtype = x.get_dtype()
    if x_dtype == dtype:
        return clone(x) if copy else x

    def _get_primitive_bitwidth(dtype):
        info = torch.finfo if dtype.is_floating_point else torch.iinfo
        return info(dtype).bits

    src_bits = _get_primitive_bitwidth(x_dtype)
    dst_bits = _get_primitive_bitwidth(dtype)
    if src_bits == dst_bits:
        return TensorBox(DtypeView.create(x, dtype))
    # fallback to aten eager implementation for differing bitwidths
    return fallback_handler(aten.view.dtype)(x, dtype)
@register_lowering(aten.view.dtype, type_promotion_kind=None)
def _view_dtype(x: TensorBox, dtype: torch.dtype):
    """Lower ``aten.view.dtype``: complex dtypes use ``ir.ComplexView``, others a bitcast."""
    if not (dtype.is_complex or x.get_dtype().is_complex):
        return to_dtype_bitcast(x, dtype)
    complex_view = ir.ComplexView.create(torch.ops.aten.view.dtype, x, dtype)
    return TensorBox.create(complex_view)
def to_device(x: TensorBox, device: torch.device, *, copy=False, non_blocking=False):
    """Move ``x`` to ``device`` via ``ir.DeviceCopy``.

    When ``x`` is already on the (decoded) device it is returned as-is, or
    cloned if ``copy`` is true.
    """
    device = decode_device(device)
    if x.get_device() != device:
        return TensorBox.create(ir.DeviceCopy.create(x, device, non_blocking))
    if copy:
        return clone(x)
    return x
@register_lowering(prims.device_put, type_promotion_kind=None)
def _device_put(x: TensorBox, device: torch.device, non_blocking=False):
    """Lower ``prims.device_put`` to ``to_device`` with ``copy=True``."""
    return to_device(x, device, non_blocking=non_blocking, copy=True)
def register_pointwise(
    aten_fn,
    name=None,
    broadcast=True,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
    convert_input_to_bool=False,
    override_return_dtype=None,
    override_fn_when_input_bool=None,
    allow_alpha=False,
    triton_fallback=None,
):
    """A pointwise function that maps ops.{name} to inputs

    ``name`` defaults to ``aten_fn.__name__``. The lowering is registered for
    ``aten_fn`` and, if ``prims`` has an attribute called ``name``, for that
    prim too (without type promotion). Returns the registered lowering.
    """
    name = name or aten_fn.__name__
    op_fn = ops_wrapper(name)

    register_op_dtype_propagation_rules(
        name, type_promotion_kind, override_return_dtype
    )

    if override_fn_when_input_bool is not None:
        override_fn_when_input_bool = ops_wrapper(override_fn_when_input_bool)

    pointwise_fn = make_pointwise(
        op_fn,
        override_return_dtype=override_return_dtype,
        override_fn_when_input_bool=override_fn_when_input_bool,
        allow_alpha=allow_alpha,
        triton_fallback=triton_fallback,
    )
    decorate_aten = register_lowering(
        aten_fn,
        broadcast=broadcast,
        type_promotion_kind=type_promotion_kind,
        convert_input_to_bool=convert_input_to_bool,
    )
    lowered = decorate_aten(pointwise_fn)

    if hasattr(prims, name):
        decorate_prim = register_lowering(
            getattr(prims, name),
            type_promotion_kind=None,
            convert_input_to_bool=convert_input_to_bool,
        )
        decorate_prim(lowered)
    return lowered
def register_frexp():
    """A pointwise function that maps ops.frexp to inputs

    ``ops.frexp`` yields a pair, so two pointwise lowerings are built (one per
    element, the second with an ``int32`` result) and the registered lowering
    returns both.
    """
    name = "frexp"
    frexp = ops_wrapper(name)

    def frexp0(*args, **kwargs):
        return frexp(*args, **kwargs)[0]  # type: ignore[index]

    def frexp1(*args, **kwargs):
        return frexp(*args, **kwargs)[1]  # type: ignore[index]

    first_pw = make_pointwise(frexp0)
    second_pw = make_pointwise(frexp1, override_return_dtype=torch.int32)

    def fn(*args, **kwargs):
        return first_pw(*args, **kwargs), second_pw(*args, **kwargs)

    fn = register_lowering(aten.frexp)(fn)

    if hasattr(prims, name):
        register_lowering(getattr(prims, name), type_promotion_kind=None)(fn)
    return fn


register_frexp()
def register_foreach_pointwise(
    aten_fn,
    pointwise_lowering_fn,
    allow_alpha=False,
):
    """Register a foreach lowering for ``aten_fn`` built from a pointwise lowering.

    Returns the wrapper produced by ``_register_foreach_lowering``.
    """
    foreach_fn = make_foreach_pointwise(pointwise_lowering_fn, allow_alpha=allow_alpha)
    return _register_foreach_lowering(aten_fn, foreach_fn)
@register_lowering(aten.where, broadcast=False, type_promotion_kind=None)
def where(cond, a, b):
    """Lower ``aten.where``.

    Python scalars in ``a``/``b`` become constants shaped like the other
    operand, tensor operands are broadcast against each other, and both value
    operands are cast to their promoted dtype before the pointwise select.
    """

    def fn(*args):
        return ops.where(*args)

    if isinstance(a, (float, int)):
        a = constant_like(a)(b)
    if isinstance(b, (float, int)):
        b = constant_like(b)(a)

    operands = [cond, a, b]
    dtype = get_promoted_dtype(
        a, b, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )

    tensor_slots = [i for i, x in enumerate(operands) if isinstance(x, TensorBox)]
    broadcasted = broadcast_tensors(*[operands[i] for i in tensor_slots])
    for slot, tensor in zip(tensor_slots, broadcasted):
        operands[slot] = tensor

    # Constants take the size of the first (already broadcast) tensor operand.
    for slot, operand in enumerate(operands):
        if isinstance(operand, ir.Constant):
            operands[slot] = ExpandView.create(
                operand, list(operands[tensor_slots[0]].get_size())
            )

    return make_pointwise(fn, override_return_dtype=dtype)(
        operands[0], to_dtype(operands[1], dtype), to_dtype(operands[2], dtype)
    )
@register_lowering(aten.broadcast_tensors, broadcast=False, type_promotion_kind=None)
def broadcast_tensors(*inputs):
    """Broadcast all inputs to their common shape.

    Accepts either several tensors or a single list/tuple of tensors. Inputs
    that already match the target (same rank and the same "is size one"
    pattern per dimension) are returned unchanged; the rest are expanded.
    """
    if len(inputs) == 1 and isinstance(inputs[0], (list, tuple)):
        return broadcast_tensors(*inputs[0])

    target: list[sympy.Expr] = functools.reduce(
        broadcast_symbolic_shapes, [x.get_size() for x in inputs], []
    )

    def needs_expand(sizes):
        if len(sizes) != len(target):
            return True
        return any(
            V.graph.sizevars.is_size_one_or_false(a)
            != V.graph.sizevars.is_size_one_or_false(b)
            for a, b in zip(sizes, target)
        )

    return [expand(x, target) if needs_expand(x.get_size()) else x for x in inputs]
@register_lowering([aten.alias, aten.detach, aten.detach_, aten.lift, prims.view_of])
def nop(x):
    """Identity lowering: the input is returned unchanged."""
    return x  # AOT autograd handles this for us


# lift_fresh is registered only when this build of aten provides it.
if hasattr(aten, "lift_fresh"):
    register_lowering(aten.lift_fresh)(nop)
@register_lowering(aten.squeeze, type_promotion_kind=None)
def squeeze(x, dim=None):
    """Lower ``aten.squeeze``.

    With ``dim=None`` a ``SqueezeView`` over the whole tensor is returned.
    Otherwise only the listed dimensions are dropped, and only where the size
    is known to equal 1; if nothing changes, ``x`` itself is returned.
    """
    assert isinstance(x, TensorBox)
    if dim is None:
        return TensorBox(SqueezeView.create(x.data))

    if isinstance(dim, (int, sympy.Expr)):
        dim = V.graph.sizevars.guard_int(dim)
    else:
        dim = tuple(V.graph.sizevars.guard_int(d) for d in dim)
    dim = canonicalize_dims(len(x.get_size()), dim)  # type: ignore[call-overload]
    dims = OrderedSet(dim if isinstance(dim, tuple) else (dim,))

    new_shape = [
        s
        for d, s in enumerate(x.get_size())
        if d not in dims or not V.graph.sizevars.guard_or_false(sympy.Eq(s, 1))
    ]

    # squeeze does nothing if the size isn't 1
    if new_shape != x.get_size():
        return view(x, new_shape)
    return x
@register_lowering(aten.squeeze_copy, type_promotion_kind=None)
def squeeze_copy(x, dim=None):
    """Lower ``aten.squeeze_copy`` as a clone of the squeezed view."""
    squeezed = squeeze(x, dim)
    return clone(squeezed)
@register_lowering([aten.squeeze_])
def squeeze_(x, dim=None):
    """In-place squeeze: rebinds ``x.data`` to the squeezed data and returns ``x``."""
    squeezed = squeeze(x, dim)
    assert isinstance(x, TensorBox)
    assert isinstance(squeezed, TensorBox)
    x.data = squeezed.data
    return x
@register_lowering(aten.isinf)
def isinf(x):
    """Lower ``aten.isinf``; integer-typed inputs give an all-False bool tensor."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("isinf"), override_return_dtype=torch.bool)(x)
    return full_like(x, False, dtype=torch.bool)
@register_lowering(aten.isnan)
def isnan(x):
    """Lower ``aten.isnan``; integer-typed inputs give an all-False bool tensor."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("isnan"), override_return_dtype=torch.bool)(x)
    return full_like(x, False, dtype=torch.bool)
@register_lowering(aten.ceil)
def ceil(x):
    """Lower ``aten.ceil``; integer-typed inputs are simply cloned."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("ceil"))(x)
    return clone(x)
@register_lowering(aten.floor)
def floor(x):
    """Lower ``aten.floor``; integer-typed inputs are simply cloned."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("floor"))(x)
    return clone(x)
@register_lowering(aten.round.default)
def round(x):
    """Lower ``aten.round.default``; integer-typed inputs are simply cloned."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("round"))(x)
    return clone(x)
@register_lowering(aten.trunc)
def trunc(x):
    """Lower ``aten.trunc``; integer-typed inputs are simply cloned."""
    if not is_integer_type(x):
        return make_pointwise(ops_wrapper("trunc"))(x)
    return clone(x)
@register_lowering(aten.expand, type_promotion_kind=None)
def expand(x, sizes):
    """Lower ``aten.expand`` to an ``ExpandView``.

    Scalars are first promoted to constants. If ``x`` already has the
    requested size it is returned as-is. When sizes are free of unbacked
    symbols and ``x`` is non-empty, ``x.mark_reuse`` is called with the
    expansion factor before the view is created.
    """
    (x,) = promote_constants([x])
    if isinstance(x, ir.BaseConstant):
        return ExpandView.create(x, tuple(sizes))
    assert isinstance(x, TensorBox)
    assert isinstance(sizes, (list, tuple))
    if tuple(x.get_size()) == tuple(sizes):
        return x

    if not free_unbacked_symbols(x.get_size()):
        x_size_product = V.graph.sizevars.size_hint_or_throw(
            sympy_product(x.get_size())
        )
        # TODO: It would be better to realize the input if any of its sizes
        # are unbacked, because typically the size will be non-zero.  However,
        # this cannot be done directly as below as we'll choke on the size_hint
        # here
        if x_size_product > 0 and not free_unbacked_symbols(sizes):
            # maybe realize input before broadcasting it
            expanded_product = V.graph.sizevars.size_hint_or_throw(
                sympy_product(sizes)
            )
            x.mark_reuse(expanded_product // x_size_product)

    return TensorBox(ExpandView.create(x.data, tuple(sizes)))
@register_lowering(prims.broadcast_in_dim, type_promotion_kind=None)
def broadcast_in_dim(a, shape, broadcast_dimensions):
    """Lower ``prims.broadcast_in_dim``.

    Every output dimension not listed in ``broadcast_dimensions`` is inserted
    with ``unsqueeze`` (in increasing order), then the result is expanded to
    ``shape``.
    """
    # Mark the dimensions that come from `a` with -1; the rest are new.
    marked = list(shape)
    for existing_dim in broadcast_dimensions:
        marked[existing_dim] = -1

    result = a
    for idx, size in enumerate(marked):
        if size != -1:
            result = unsqueeze(result, idx)

    return expand(result, shape)
@register_lowering(aten.expand_as, type_promotion_kind=None)
def expand_as(x, y):
    """Lower ``aten.expand_as``: expand ``x`` to the size of ``y``."""
    target_size = y.get_size()
    return expand(x, target_size)
@register_lowering(aten.repeat)
def repeat(x, repeats):
    """Lower ``aten.repeat``.

    ``x`` is first viewed with leading size-1 dimensions so its rank matches
    ``len(repeats)``. A zero repeat yields an empty tensor; if every dimension
    has either repeat 1 or size 1 the result is a cloned expand; otherwise a
    ``Pointwise`` that loads with indices wrapped modulo the original size.
    """
    old_size = list(x.get_size())
    if len(repeats) > len(old_size):
        padding = [sympy.S.One] * (len(repeats) - len(old_size))
        old_size = padding + old_size
        x = view(x, list(old_size))
    assert len(repeats) == len(x.get_size())

    new_size = list(x.get_size())

    zero_tensor = False
    for i, count in enumerate(repeats):
        if count == 0:
            zero_tensor = True
        new_size[i] = new_size[i] * count

    if zero_tensor:
        return empty(new_size, dtype=x.get_dtype(), device=x.get_device())
    if all((a == 1 or b == 1) for a, b in zip(repeats, old_size)):
        return clone(expand(x, new_size))

    x_loader: Callable[[Any], Any]

    def inner_fn(index):
        assert len(index) == len(repeats)
        index = list(index)
        for i, count in enumerate(repeats):
            if count == 1:
                continue
            if old_size[i] == 1:
                index[i] = sympy.S.Zero
            else:
                index[i] = ModularIndexing(index[i], 1, old_size[i])
        return x_loader(index)

    if not free_unbacked_symbols(old_size) and not free_unbacked_symbols(new_size):
        old_size_product = V.graph.sizevars.size_hint_or_throw(sympy_product(old_size))
        if old_size_product > 0:
            # maybe realize the input but skip for unbacked symints since it'll
            # choke on the size hint.
            new_size_product = V.graph.sizevars.size_hint_or_throw(
                sympy_product(new_size)
            )
            x.mark_reuse(new_size_product // old_size_product)

    # The loader is created only after the optional mark_reuse above.
    x_loader = x.make_loader()
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(new_size),
    )
@register_lowering(aten._unsafe_view, type_promotion_kind=None)
@register_lowering(aten.view, type_promotion_kind=None)
@register_lowering(aten.reshape, type_promotion_kind=None)
def view(x: TensorBox, sizes: Sequence[sympy.Expr]) -> TensorBox:
    """Lower ``_unsafe_view``/``view``/``reshape`` to a ``View`` over ``x.data``."""
    reshaped = View.create(x.data, sizes)
    return TensorBox(reshaped)
@register_lowering(aten.permute, type_promotion_kind=None)
def permute(x, dims):
    """Lower ``aten.permute`` to a ``PermuteView`` over ``x.data``."""
    assert isinstance(x, TensorBox)
    assert isinstance(dims, (list, tuple))
    permuted = PermuteView.create(x.data, tuple(dims))
    return TensorBox(permuted)
@register_lowering(aten.slice, type_promotion_kind=None)
def slice_(x, dim=0, start=0, end=2**63, step=1, clamp=True):
    """Lower ``aten.slice`` to an ``ir.SliceView`` along the validated ``dim``."""
    assert isinstance(x, TensorBox)
    dim = _validate_dim(x, dim, 0)
    sliced = ir.SliceView.create(x.data, dim, start, end, step, clamp=clamp)
    return TensorBox(sliced)
@register_lowering(aten.as_strided, type_promotion_kind=None)
def as_strided(x, size, stride, storage_offset=None):
    """Lower ``aten.as_strided`` to a ``ReinterpretView`` with a fixed layout.

    The input is unwrapped from any view and realized first. The new layout
    keeps the old device and dtype and uses the given size, stride and
    storage offset (0 when not given).

    Raises:
        NotImplementedError: if the realized input is not storage-and-layout.
    """
    if isinstance(x, TensorBox) and isinstance(x.data, ir.BaseView):
        # as_strided ignores views
        x = x.data.unwrap_view()
    x.realize()
    if not ir.is_storage_and_layout(x):
        raise NotImplementedError(f"unrealized as_strided({x}, ...)")

    storage, old_layout = ir.as_storage_and_layout(x)
    expanded_size = [sympy.expand(s) for s in size]
    expanded_stride = [sympy.expand(s) for s in stride]
    expanded_offset = sympy.expand(storage_offset or 0)
    new_layout = ir.FixedLayout(
        old_layout.device,
        old_layout.dtype,
        expanded_size,
        expanded_stride,
        expanded_offset,
    )
    return TensorBox(ir.ReinterpretView(data=storage, layout=new_layout))
@register_lowering(aten.as_strided_, type_promotion_kind=None)
def as_strided_(x, size, stride, storage_offset=None):
    """In-place ``as_strided``: rebinds ``x.data`` to the strided view and returns ``x``."""
    assert isinstance(x, TensorBox)
    strided = as_strided(x, size, stride, storage_offset)
    x.data = strided.data
    return x
@register_lowering(aten.as_strided_copy, type_promotion_kind=None)
def as_strided_copy(x, size, stride, storage_offset=None):
    """Lower ``aten.as_strided_copy`` as a clone of the ``as_strided`` view."""
    return clone(as_strided(x, size, stride, storage_offset))
def pointwise_cat(inputs, dim=0):
    """Concatenate ``inputs`` along ``dim`` as a single ``Pointwise`` node.

    Each input gets a half-open index range along ``dim``. The generated body
    performs a masked load from every input and then selects between them with
    a chain of ``where`` ops. Device, dtype and the remaining sizes are taken
    from the first input.
    """
    # (inclusive, exclusive)
    inputs_ranges: list[tuple[sympy.Expr, sympy.Expr]] = []
    prev_end = 0
    for inp in inputs:
        next_end = prev_end + inp.get_size()[dim]
        inputs_ranges.append((prev_end, next_end))  # type: ignore[arg-type]
        prev_end = next_end  # type: ignore[assignment]

    inputs_loaders = [inp.make_loader() for inp in inputs]
    last_input = len(inputs) - 1

    def inner_fn(idx):
        idx_dim = ops.index_expr(idx[dim], torch.int64)

        masks = []
        masked_loads = []
        for i, (range_start, range_end) in enumerate(inputs_ranges):
            if i == 0:
                start = ops.constant(0, torch.int64)
            else:
                start = ops.index_expr(range_start, torch.int64)
            end = ops.index_expr(range_end, torch.int64)

            start_cond = ops.ge(idx_dim, start)
            end_cond = ops.lt(idx_dim, end)
            # The first/last inputs only need a one-sided bound check.
            if i == 0:
                mask = end_cond
            elif i == last_input:
                mask = start_cond
            else:
                mask = ops.and_(start_cond, end_cond)
            masks.append(mask)

            idx_load = list(idx)
            # if we're concatting [4], [2]
            # when we index the second tensor for 5 we want to index 5 - 4
            # Use Identity to prevent expansion of index * stride to keep expression
            # in same int bitwidth as shape
            idx_load[dim] = Identity(idx_load[dim] - range_start)
            masked_loads.append(
                ops.masked(
                    mask,
                    lambda: inputs_loaders[i](idx_load),
                    0.0,  # this value should be unused
                ),
            )

        # Fold from the last input backwards into nested where() selections.
        next_val = masked_loads[-1]
        for mask, load in zip(reversed(masks[:-1]), reversed(masked_loads[:-1])):
            next_val = ops.where(mask, load, next_val)
        return next_val

    new_size = list(inputs[0].get_size())
    new_size[dim] = inputs_ranges[-1][-1]

    return Pointwise.create(
        device=inputs[0].get_device(),
        dtype=inputs[0].get_dtype(),
        inner_fn=inner_fn,
        ranges=new_size,
    )
@register_lowering(quantized_decomposed.quantize_per_channel, type_promotion_kind=None)
def quantized_decomposed_quantize_per_channel(
    input: TensorBox,
    scales: TensorBox,
    zero_points: TensorBox,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Lower per-channel quantization to one ``Pointwise`` node.

    For each element the body computes
    ``round(input * (1 / scale[c])) + zero_point[c]``, where ``c`` is the
    element's index along ``axis``, clamps it to ``[quant_min, quant_max]`` and
    casts to ``dtype``.  ``bfloat16`` inputs are converted to ``float32``
    first; any other non-``float32`` input fails the assertion.
    """
    assert len(scales.get_size()) == 1, "expect scales 1 dim"
    assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"

    if input.get_dtype() == torch.bfloat16:
        input = to_dtype(input, torch.float32)
    assert input.get_dtype() == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
    )
    assert axis < len(input.get_size()), (
        f"Expecting axis to be < {len(input.get_size())}"
    )

    load_input = input.make_loader()
    load_scale = scales.make_loader()
    load_zero_point = zero_points.make_loader()

    def inner_fn(idx):
        # scales / zero_points are 1-D and indexed by the position along ``axis``.
        per_channel_idx = (idx[axis],)
        x_val = load_input(idx)
        scale_val = load_scale(per_channel_idx)
        zp_val = load_zero_point(per_channel_idx)
        lower, upper = _create_constants(quant_min, quant_max, dtype=torch.float32)

        if scales.dtype != torch.float32:
            scale_val = ops.to_dtype(scale_val, torch.float32)
        if zero_points.dtype != torch.int32:
            zp_val = ops.to_dtype(zp_val, torch.int32)
        shifted = ops.round(x_val * ops.reciprocal(scale_val)) + zp_val
        bounded = ops.maximum(lower, ops.minimum(upper, shifted))
        return ops.to_dtype(bounded, dtype)

    return Pointwise.create(
        device=input.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
def _assert_async(cond, msg):
    """Build and realize a ``Pointwise`` node that device-asserts ``cond``.

    ``cond`` is realized and converted to ``torch.bool``; every element is then
    passed to ``ops.device_assert_async`` together with ``msg``.  The realized
    assertion node is returned.
    """
    cond.realize()
    cond = to_dtype(cond, torch.bool)

    def check_element(index):
        with ir.ComputedBuffer.force_realize():
            load_cond = cond.make_loader()
            return ops.device_assert_async(load_cond(index), msg)

    assertion_node = Pointwise.create(
        device=cond.get_device(),
        dtype=cond.get_dtype(),
        inner_fn=check_element,
        ranges=list(cond.get_size()),
    )
    assertion_node.realize()
    return assertion_node
@register_lowering(aten._assert_async.msg)
def lower_assert_async(cond, msg):
    """Lowering for ``aten._assert_async.msg``; delegates to ``_assert_async``."""
    assertion_node = _assert_async(cond, msg)
    return assertion_node
@register_lowering(aten._functional_assert_async.msg)
def lower_assert_functional_async(cond, msg):
    """Lowering for ``aten._functional_assert_async.msg``; delegates to ``_assert_async``."""
    assertion_node = _assert_async(cond, msg)
    return assertion_node
@register_lowering(
    quantized_decomposed.dequantize_per_channel, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_channel(
    input: TensorBox,
    scales: TensorBox,
    zero_points: TensorBox,
    axis: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Lower per-channel dequantization to one ``Pointwise`` node.

    For each element the body computes
    ``(float32(input) - zero_point[c]) * scale[c]``, where ``c`` is the
    element's index along ``axis``, and casts the result to ``out_dtype``
    (``torch.float32`` when ``out_dtype`` is ``None``).  ``quant_min`` and
    ``quant_max`` are accepted but not read.
    """
    assert len(scales.get_size()) == 1, "expect scales 1 dim"
    assert len(zero_points.get_size()) == 1, "expect zero_points 1 dim"
    assert input.get_dtype() == dtype, (
        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
    )
    assert axis < len(input.get_size()), (
        f"Expecting axis to be < {len(input.get_size())}"
    )

    result_dtype = torch.float32 if out_dtype is None else out_dtype

    load_input = input.make_loader()
    load_scale = scales.make_loader()
    load_zero_point = zero_points.make_loader()

    def inner_fn(idx):
        # scales / zero_points are 1-D and indexed by the position along ``axis``.
        per_channel_idx = (idx[axis],)
        q_val = load_input(idx)
        scale_val = load_scale(per_channel_idx)
        zp_val = load_zero_point(per_channel_idx)

        if scales.dtype != torch.float32:
            scale_val = ops.to_dtype(scale_val, torch.float32)
        if zero_points.dtype != torch.float32:
            zp_val = ops.to_dtype(zp_val, torch.float32)
        centered = ops.sub(ops.to_dtype(q_val, torch.float32), zp_val)
        return ops.to_dtype(centered * scale_val, result_dtype)

    return Pointwise.create(
        device=input.get_device(),
        dtype=result_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.quantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_default(
    input: TensorBox,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Lower per-tensor quantization with Python-scalar ``scale`` / ``zero_point``.

    Each element becomes ``round(input * (1.0 / scale)) + zero_point``, clamped
    to ``[quant_min, quant_max]`` and cast to ``dtype``.  ``bfloat16`` inputs
    are converted to ``float32`` first; any other non-``float32`` input fails
    the assertion.
    """
    if input.get_dtype() == torch.bfloat16:
        input = to_dtype(input, torch.float32)
    assert input.get_dtype() == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
    )

    load_input = input.make_loader()

    def inner_fn(idx, scale, zero_point):
        x_val = load_input(idx)
        inv_scale_const, zp_const = _create_constants(
            1.0 / scale, zero_point, dtype=torch.float32
        )
        shifted = ops.round(x_val * inv_scale_const) + zp_const
        lower, upper = _create_constants(quant_min, quant_max, dtype=torch.float32)
        bounded = ops.minimum(ops.maximum(shifted, lower), upper)
        return ops.to_dtype(bounded, dtype)

    # Bind the scalars as plain Python float / int before handing the body over.
    bound_inner_fn = functools.partial(
        inner_fn, scale=float(scale), zero_point=int(zero_point)
    )
    return Pointwise.create(
        device=input.get_device(),
        dtype=dtype,
        inner_fn=bound_inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.dequantize_per_tensor.default, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_default(
    input: TensorBox,
    scale: float,
    zero_point: int,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Lower per-tensor dequantization with Python-scalar ``scale`` / ``zero_point``.

    Each element becomes ``(float32(input) - zero_point) * scale`` cast to
    ``out_dtype`` (``torch.float32`` when ``out_dtype`` is ``None``).
    ``quant_min`` and ``quant_max`` are accepted but not read.
    """
    assert input.get_dtype() == dtype, (
        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
    )

    result_dtype = torch.float32 if out_dtype is None else out_dtype
    load_input = input.make_loader()

    def inner_fn(idx, scale, zero_point):
        q_val = load_input(idx)
        scale_const, zp_const = _create_constants(
            scale, zero_point, dtype=torch.float32
        )
        centered = ops.sub(ops.to_dtype(q_val, torch.float32), zp_const)
        return ops.to_dtype(centered * scale_const, result_dtype)

    # Bind the scalars as plain Python float / int before handing the body over.
    bound_inner_fn = functools.partial(
        inner_fn, scale=float(scale), zero_point=int(zero_point)
    )
    return Pointwise.create(
        device=input.get_device(),
        dtype=result_dtype,
        inner_fn=bound_inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.quantize_per_tensor.tensor, type_promotion_kind=None
)
def quantized_decomposed_quantize_per_tensor_tensor(
    input: TensorBox,
    scale: TensorBox,
    zero_point: TensorBox,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Lower per-tensor quantization whose ``scale`` / ``zero_point`` are tensors.

    ``scale`` and ``zero_point`` must each be either 0-dim or 1-dim with a
    single element.  Each output element is
    ``round(input * (1 / scale)) + zero_point``, clamped to
    ``[quant_min, quant_max]`` and cast to ``dtype``.  ``bfloat16`` inputs are
    converted to ``float32`` first; any other non-``float32`` input fails the
    assertion.
    """
    if input.get_dtype() == torch.bfloat16:
        input = to_dtype(input, torch.float32)
    assert input.get_dtype() == torch.float32, (
        f"Expecting input to have dtype torch.float32, but got dtype: {input.get_dtype()}"
    )
    assert len(scale.get_size()) == 0 or (
        len(scale.get_size()) == 1 and scale.get_size()[0] == 1
    ), "expect scale as scalar tensor"
    assert len(zero_point.get_size()) == 0 or (
        len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
    ), "expect zero_point as scalar tensor"

    input_loader = input.make_loader()
    scale_loader = scale.make_loader()
    zero_point_loader = zero_point.make_loader()

    def inner_fn(idx):
        input = input_loader(idx)
        _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
        # The index must follow zero_point's own rank: scale and zero_point are
        # validated independently, so one may be 0-dim while the other is [1].
        _zero_point = zero_point_loader(
            (0,) if len(zero_point.get_size()) == 1 else ()
        )
        if scale.dtype != torch.float32:
            _scale = ops.to_dtype(_scale, torch.float32)
        if zero_point.dtype != torch.float32:
            _zero_point = ops.to_dtype(_zero_point, torch.float32)
        val = ops.round(input * ops.reciprocal(_scale)) + _zero_point
        qmin, qmax = _create_constants(quant_min, quant_max, dtype=torch.float32)
        clamped = ops.minimum(ops.maximum(val, qmin), qmax)
        return ops.to_dtype(clamped, dtype)

    return Pointwise.create(
        device=input.get_device(),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(
    quantized_decomposed.dequantize_per_tensor.tensor, type_promotion_kind=None
)
def quantized_decomposed_dequantize_per_tensor_tensor(
    input: TensorBox,
    scale: TensorBox,
    zero_point: TensorBox,
    quant_min: int,
    quant_max: int,
    dtype: torch.dtype,
    *,
    out_dtype: Optional[torch.dtype] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Lower per-tensor dequantization whose ``scale`` / ``zero_point`` are tensors.

    ``scale`` and ``zero_point`` must each be either 0-dim or 1-dim with a
    single element, and ``input`` must already have dtype ``dtype``.  Each
    output element is ``(float32(input) - zero_point) * scale`` cast to
    ``out_dtype`` (``torch.float32`` when ``out_dtype`` is ``None``).
    ``quant_min`` and ``quant_max`` are accepted but not read.
    """
    assert len(scale.get_size()) == 0 or (
        len(scale.get_size()) == 1 and scale.get_size()[0] == 1
    ), "expect scale as scalar tensor"
    assert len(zero_point.get_size()) == 0 or (
        len(zero_point.get_size()) == 1 and zero_point.get_size()[0] == 1
    ), "expect zero_point as scalar tensor"
    assert input.get_dtype() == dtype, (
        f"Expecting input to have dtype {dtype}, but got dtype: {input.get_dtype()}"
    )

    if out_dtype is None:
        out_dtype = torch.float32

    input_loader = input.make_loader()
    scale_loader = scale.make_loader()
    zero_point_loader = zero_point.make_loader()

    def inner_fn(idx):
        input = input_loader(idx)
        _scale = scale_loader((0,) if len(scale.get_size()) == 1 else ())
        # The index must follow zero_point's own rank: scale and zero_point are
        # validated independently, so one may be 0-dim while the other is [1].
        _zero_point = zero_point_loader(
            (0,) if len(zero_point.get_size()) == 1 else ()
        )
        if scale.dtype != torch.float32:
            _scale = ops.to_dtype(_scale, torch.float32)
        if zero_point.dtype != torch.float32:
            _zero_point = ops.to_dtype(_zero_point, torch.float32)
        val = ops.sub(ops.to_dtype(input, torch.float32), _zero_point) * _scale
        val = ops.to_dtype(val, out_dtype)
        return val

    return Pointwise.create(
        device=input.get_device(),
        dtype=out_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )
@register_lowering(aten.cat)
def cat(inputs, dim=0):
    """Lower ``aten.cat``, choosing between three strategies.

    * fall back to the eager ``aten.cat.default`` kernel (CPU with all-int8/uint8
      inputs),
    * emit one fused ``Pointwise`` via ``pointwise_cat``,
    * emit an ``ir.ConcatKernel``.

    A single input is simply cloned.  Otherwise the inputs are promoted to a
    common dtype and the heuristics below decide between the pointwise form and
    ``ConcatKernel``.
    """
    # The device of the first input decides whether the CPU paths are taken.
    cpu_device = inputs[0].get_device().type == "cpu"
    if cpu_device and all(
        input.get_dtype() in [torch.int8, torch.uint8] for input in inputs
    ):
        # TODO <leslie> Remove this fallback when we support vectorization
        # code gen with uint8 data type directly.
        for input in inputs:
            input.realize()
        if all(len(input.get_size()) == 4 for input in inputs):
            inputs, _ = require_channels_last(aten.cat, *inputs)
        return fallback_handler(aten.cat.default)(inputs, dim)

    if len(inputs) == 1:
        return clone(inputs[0])

    dim = _validate_dim(inputs[0], dim, 0)
    dtype = get_promoted_dtype(
        *inputs, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
    )
    inputs = [to_dtype(inp, dtype) for inp in inputs]

    def unwrap_tensor(x: Union[TensorBox, ir.StorageBox]) -> ir.IRNode:
        # Strip one TensorBox / StorageBox layer (and a view, if present).
        if isinstance(x, TensorBox):
            if isinstance(x.data, ir.BaseView):
                return x.data.unwrap_view()
            else:
                return x.data

        if isinstance(x, ir.StorageBox):
            return x.data

        return x

    def is_reduction(t):
        return isinstance(t, ir.ComputedBuffer) and isinstance(t.data, ir.Reduction)

    def can_fuse_reduction(t):
        # True if ``t`` is a reduction buffer, or a Pointwise that (transitively)
        # reads from one.
        if isinstance(t, (TensorBox, ir.StorageBox)):
            return can_fuse_reduction(unwrap_tensor(t))
        return (
            is_reduction(t)
            or isinstance(t, ir.Pointwise)
            and any(
                can_fuse_reduction(V.graph.get_buffer(read))
                for read in t.get_read_names()
            )
        )

    # fusing reductions into computed concat buffer can cause regressions.
    fusable_reduction = any(can_fuse_reduction(t) for t in inputs)

    def should_lower_cat_input(x) -> bool:
        # Unrealized inputs will not be storage and layouts, and we don't want to realize
        # them in case we want to fuse
        if ir.is_storage_and_layout(x):
            storage, _ = ir.as_storage_and_layout(x, freeze=False)
            return not ir.ConcatKernel.can_realize_into_without_copy(storage)

        if isinstance(x, (TensorBox, ir.StorageBox)):
            return should_lower_cat_input(unwrap_tensor(x))

        if isinstance(x, ir.Pointwise):
            return True

        return False

    if config.force_pointwise_cat:
        return pointwise_cat(inputs, dim)

    # TODO: We observed negative performance impact of pointwise_cat optimization on CPU so disabled it.
    # We will revisit this later after enabling vectorization on index_expr.
    if cpu_device:
        return TensorBox(ir.ConcatKernel.create(inputs, dim))

    def op_count(x):
        # Total op count of a Pointwise input, including the Pointwise buffers it reads.
        if isinstance(x, (TensorBox, ir.StorageBox)):
            return op_count(unwrap_tensor(x))

        # this will correspond to a direct memory read
        if not isinstance(x, ir.Pointwise):
            return 0

        count = x.inner_fn_opcount().num_ops
        for read in x.get_read_names():
            count += op_count(V.graph.get_buffer(read))

        return count

    # as the number of inputs increases, the possibility of register spilling also
    # increases; past a certain threshold of inputs we only fuse if the input
    # kernels are simple
    # not sure if we want to expose to users via config since logic may change in future
    MAX_COMPLEX_POINTWISE_CAT = 8
    MAX_SIMPLE_OP_COUNT = 2

    def additional_pointwise_ops(op: torch._ops.OpOverload):
        return op in (aten.cat.default, aten.constant_pad_nd.default)

    if len(inputs) <= MAX_COMPLEX_POINTWISE_CAT or (
        (len(inputs) <= config.max_pointwise_cat_inputs)
        and all(op_count(t) <= MAX_SIMPLE_OP_COUNT for t in inputs)
    ):
        pointwise_uses = all(
            is_pointwise_use(use, additional_pointwise_ops)
            for use in V.current_node.users
        )
        # fuse in case we will be used in a pointwise node, and there are any inputs
        # we can prevent materialization of.
        fuse_pointwise_use = (
            any(should_lower_cat_input(inp) for inp in inputs) and pointwise_uses
        )

        # horizontal fuse in case all inputs will require a copy kernel anyway.
        # only horizontally fuse pointwise kernels
        horizontal_fuse_cat = all(
            should_lower_cat_input(inp) for inp in inputs
        ) and not any(can_fuse_reduction(t) for t in inputs)
        if fuse_pointwise_use or (horizontal_fuse_cat and not fusable_reduction):
            return pointwise_cat(inputs, dim)

    return TensorBox(ir.ConcatKernel.create(inputs, dim))
@register_lowering(aten.diagonal, type_promotion_kind=None)
def diagonal(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
    """Lower ``aten.diagonal`` to a ``GenericView`` over ``input``.

    The output keeps every dimension except ``dim1`` and ``dim2`` (in their
    original order) and appends one trailing dimension holding the diagonal.
    ``check`` rejects ``dim1 == dim2`` after canonicalization.
    """
    original_shape = input.get_size()
    num_dims = len(original_shape)
    dim1 = canonicalize_dim(idx=dim1, rank=num_dims)
    dim2 = canonicalize_dim(idx=dim2, rank=num_dims)

    check(
        dim1 != dim2, lambda: f"diagonal dimensions cannot be identical {dim1}, {dim2}"
    )

    # The sign of the offset picks both the diagonal length and which of the
    # two dimensions starts at a non-zero position.
    if V.graph.sizevars.evaluate_expr(sympy.Lt(offset, 0)):
        diag_size = V.graph.sizevars.evaluate_max(
            V.graph.sizevars.evaluate_min(
                original_shape[dim1] + offset, original_shape[dim2]
            ),
            0,  # type: ignore[arg-type]
        )
        base_idx = (-offset, 0)
    else:
        diag_size = V.graph.sizevars.evaluate_max(
            V.graph.sizevars.evaluate_min(
                original_shape[dim1], original_shape[dim2] - offset
            ),
            0,  # type: ignore[arg-type]
        )
        base_idx = (0, offset)

    sizes = [
        extent for axis, extent in enumerate(original_shape) if axis not in (dim1, dim2)
    ]
    sizes.append(diag_size)

    def reindexer(idx):
        # The last output coordinate walks along the diagonal; the preceding
        # ones map, in order, to the dimensions that were kept.
        diag_idx = idx[-1]
        source_idx = []
        consumed = 0
        for axis in range(num_dims):
            if axis == dim1:
                source_idx.append(diag_idx + base_idx[0])
            elif axis == dim2:
                source_idx.append(diag_idx + base_idx[1])
            else:
                source_idx.append(idx[consumed])
                consumed += 1

        assert consumed == len(original_shape) - 2
        return source_idx

    return TensorBox(ir.GenericView.create(input, sizes, reindexer))
@register_lowering(aten.diagonal_copy, type_promotion_kind=None)
def diagonal_copy(input, offset: int = 0, dim1: int = 0, dim2: int = 1):
    """Lower ``aten.diagonal_copy`` as ``clone`` applied to the ``diagonal`` view."""
    diagonal_view = diagonal(input, offset, dim1, dim2)
    return clone(diagonal_view)
@register_lowering(aten.diagonal_scatter, type_promotion_kind=None)
def diagonal_scatter(input, src, offset: int = 0, dim1: int = 0, dim2: int = 1):
    """Lower ``aten.diagonal_scatter``.

    Clones ``input``, takes the ``diagonal`` view of the clone, writes ``src``
    into that view with ``mutate_to`` and returns the clone.
    """
    result = clone(input)
    mutate_to(diagonal(result, offset, dim1, dim2), src)
    return result
@register_lowering(aten.select, type_promotion_kind=None)
def select(x, dim, idx):
    """Lower ``aten.select``: pick index ``idx`` along ``dim`` and drop that dim.

    Three paths are taken depending on what can be proven about ``idx``:

    * sign known and ``idx`` has no free unbacked symbols: ``slice_`` followed
      by ``squeeze``;
    * sign known but ``idx`` has free unbacked symbols: an ``as_strided`` view
      with an explicitly computed storage offset;
    * sign unknown: the offset is computed at runtime through an
      ``ir.DynamicSelectStorageOffset`` buffer bound to the node's single
      unbacked binding, and the result is again an ``as_strided`` view.

    The size and stride lists are copied before the selected dimension is
    deleted, so the lists returned by ``x.get_size()`` / ``x.get_stride()`` are
    never modified in place.
    """
    idx = sympy.expand(idx)
    size = sympy.expand(x.get_size()[dim])
    actual_index = None

    # guard_or_false answers False when the sign cannot be decided, so both
    # branches may be skipped and actual_index stays None.
    if V.graph.sizevars.guard_or_false(sympy.Lt(idx, 0)):
        actual_index = idx + size
    elif V.graph.sizevars.guard_or_false(sympy.Ge(idx, 0)):
        actual_index = idx

    if actual_index is not None:
        if has_free_unbacked_symbols(idx):
            # Inductor could generate incorrect views for tensors with unbacked symbols here;
            # Squeeze operations are translated to views, resulting in incorrect strides.
            # Additionally, we want to avoid accidental unbacked unsqueeze semantics. To resolve this,
            # we use as_strided instead.
            # Removing this branch will cause test_unbacked_select_index_with_check to fail.
            new_size = list(x.get_size())
            new_stride = list(x.get_stride())
            new_storage_offset = x.get_layout().offset + new_stride[dim] * actual_index

            del new_size[dim]
            del new_stride[dim]
            return as_strided(x, new_size, new_stride, new_storage_offset)
        else:
            slice_result = slice_(x, dim, actual_index, actual_index + 1)
            return squeeze(slice_result, dim)

    # Unbacked Semantics:
    # When the index idx is unbacked (e.g., u0), we compute the index dynamically
    # during the lowering of the select operation using DynamicSelectStorageOffset.

    unbacked_bindings = resolve_unbacked_bindings(
        V.graph.sizevars.shape_env, V.graph.current_node.meta["unbacked_bindings"]
    )
    assert unbacked_bindings is not None
    assert len(unbacked_bindings) == 1, unbacked_bindings
    unbacked_offset_sym, _ = next(iter(unbacked_bindings.items()))

    new_size = list(x.get_size())
    new_stride = list(x.get_stride())

    # The symbol stands in for the storage offset; the buffer registered below
    # is what receives idx, the base offset, the stride and the dim size.
    new_storage_offset = unbacked_offset_sym
    buffer = ir.DynamicSelectStorageOffset(
        unbacked_offset_sym,
        idx,
        x.get_layout().offset,
        new_stride[dim],
        x.get_size()[dim],
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)

    del new_size[dim]
    del new_stride[dim]
    return as_strided(x, new_size, new_stride, new_storage_offset)
@register_lowering(aten.split, type_promotion_kind=None)
def split(x, sizes, dim=0):
    """Lower ``aten.split`` to a list of ``slice_`` views along ``dim``.

    Args:
        x: tensor to split.
        sizes: either an explicit list/tuple of chunk sizes, or a single
            split size (int or symbolic) from which the chunk sizes are derived.
        dim: dimension to split along; validated with ``_validate_dim``.

    Returns:
        A list with one slice per chunk.  When ``sizes`` is a single split size
        at least one chunk is always produced, so a zero-length dimension
        yields one empty slice.
    """
    dim = _validate_dim(x, dim, 0)
    sizes_ = sizes

    # If sizes is an integer (or a SymInt), we turn it into a list of sizes
    # by computing what the actual size of each chunk should be.
    if not isinstance(sizes, (list, tuple)):
        x_size = x.get_size()[dim]
        chunks = V.graph.sizevars.guard_int(FloorDiv(x_size + sizes - 1, sizes))
        # A zero-length dimension gives chunks == 0; keep one (empty) chunk so
        # the "last chunk" fix-up below has an element to write to.
        chunks = max(chunks, 1)
        sizes_ = [sizes] * chunks
        # The last chunk might have a smaller size than the rest.
        sizes_[-1] = x_size - (chunks - 1) * sizes

    # From this point, we assume that the sum of the sizes of all chunks
    # equals the size of the base tensor.
    result = []
    start = 0
    for size in sizes_:
        end = start + size
        # No need for clamping here, since we compute the exact
        # start and end values.
        result.append(slice_(x, dim, start, end, clamp=False))
        start = end
    return result
@register_lowering(aten.split_with_sizes, type_promotion_kind=None)
def split_with_sizes(x, sizes, dim=0):
    """Lower ``aten.split_with_sizes`` by delegating to ``split``."""
    pieces = split(x, sizes, dim)
    return pieces
@register_lowering(aten.unbind, type_promotion_kind=None)
def unbind(x, dim=0):
    """Lower ``aten.unbind`` to a list with one ``select`` per index along ``dim``.

    The length of ``dim`` is turned into a concrete int with ``guard_int``.
    """
    dim = _validate_dim(x, dim, 0)
    length = V.graph.sizevars.guard_int(x.get_size()[dim])
    pieces = []
    for position in range(length):
        pieces.append(select(x, dim, position))
    return pieces
@register_lowering(aten.unfold, type_promotion_kind=None)
def unfold(x, dimension, size, step):
    """Lower ``aten.unfold`` to a ``GenericView`` of sliding windows.

    ``dimension`` shrinks to ``(dim_size - size) // step + 1`` windows and a
    trailing dimension of length ``size`` is appended.  A 0-dim input is
    handled by unsqueezing and slicing to ``size``.
    """
    sizes = x.get_size()
    ndim = len(sizes)
    dim = canonicalize_dim(ndim, dimension)

    if ndim == 0:
        return slice_(unsqueeze(x, 0), end=size)

    dim_size = sizes[dim]
    sizevars = V.graph.sizevars
    sizevars.check_leq(size, dim_size)
    sizevars.check_lt(0, step)  # type: ignore[arg-type]

    window_count = FloorDiv(dim_size - size, step) + 1
    if sizevars.size_hint_or_throw(dim_size) > 0:
        reuse_hint = sizevars.size_hint_or_throw(
            CeilDiv(window_count * size, dim_size)
        )
        x.mark_reuse(reuse_hint)

    leading = sizes[:dim]
    trailing = sizes[dim + 1 :]
    out_size = [*leading, window_count, *trailing, size]

    def reindexer(idx):
        # idx[dim] selects the window, idx[-1] the position inside it.
        source_pos = idx[-1] + idx[dim] * step
        return (*idx[:dim], source_pos, *idx[dim + 1 : -1])

    return TensorBox(ir.GenericView.create(x, out_size, reindexer))
@register_lowering(aten.unsqueeze, type_promotion_kind=None)
def unsqueeze(x, dim):
    """Lower ``aten.unsqueeze`` as a ``view`` with a size-1 entry inserted at ``dim``.

    ``dim`` is validated with an offset of 1, so it may equal the current rank.
    """
    dim = _validate_dim(x, dim, 1)
    current_shape = list(x.get_size())
    expanded_shape = [*current_shape[:dim], sympy.S.One, *current_shape[dim:]]
    return view(x, expanded_shape)
@register_lowering(aten.unsqueeze_, type_promotion_kind=None)
def unsqueeze_(x, dim):
    """In-place variant: point ``x.data`` at the ``unsqueeze`` result and return ``x``."""
    unsqueezed = unsqueeze(x, dim)
    assert isinstance(x, TensorBox)
    assert isinstance(unsqueezed, TensorBox)
    x.data = unsqueezed.data
    return x
def _validate_dim(x, dim, offset=0):
    """Evaluate ``dim`` and wrap a negative value into ``[0, ndim + offset)``.

    ``dim`` is evaluated through the graph's shape environment.  A negative
    value has ``ndim + offset`` added; the result is asserted to lie in
    ``[0, ndim + offset)`` and returned.
    """
    resolved = V.graph.sizevars.shape_env.evaluate_expr(sympy.sympify(dim))
    upper_bound = len(x.get_size()) + offset
    if resolved < 0:
        resolved = resolved + upper_bound
    assert 0 <= resolved < upper_bound
    return resolved
@register_lowering(aten.glu)
def glu(x, dim=-1):
    """Lower ``aten.glu``: ``first_half * sigmoid(second_half)`` along ``dim``."""
    dim = _validate_dim(x, dim, 0)
    # TODO: don't guard on static shape here
    half = V.graph.sizevars.guard_int(x.get_size()[dim]) // 2
    value_part = slice_(x, dim, 0, half)
    gate_part = slice_(x, dim, half, half * 2)
    return mul(value_part, sigmoid(gate_part))
def fallback_handler(kernel, add_to_fallback_set=True):
    """Return a lowering function that routes ``kernel`` through ``ir.FallbackKernel``.

    When ``add_to_fallback_set`` is true, ``kernel`` is also recorded in
    ``fallbacks``.  The returned handler wraps every ``ir.IRNode`` in the
    fallback's result with ``TensorBox.create`` and leaves other values as is.
    """
    if add_to_fallback_set:
        fallbacks.add(kernel)

    def box_if_ir_node(value):
        if isinstance(value, ir.IRNode):
            return TensorBox.create(value)
        return value

    def handler(*args, **kwargs):
        fallback_result = ir.FallbackKernel.create(kernel, *args, **kwargs)
        return pytree.tree_map(box_if_ir_node, fallback_result)

    # This lets us detect that a lowering is a fallback handler.
    handler._is_fallback_handler = True  # type: ignore[attr-defined]
    return handler
@functools.cache
def _warn_complex_not_supported():
    """Emit the complex-operator warning; cached, so it is issued once per process."""
    message = "Torchinductor does not support code generation for complex operators. Performance may be worse than eager."
    warnings.warn(message)
# There are some types (CPU) which we accept as input but not as
# output.
def unsupported_input_tensor(t: torch.Tensor, node=None):
    """Return True when ``t`` can be neither read nor written by generated code.

    Complex tensors (after a one-time warning) and meta tensors are always
    unsupported.  ``float8_e8m0fnu`` tensors are unsupported unless ``node`` is
    given and its target is one of a small set of allowed ops or a view op.
    """
    if t.is_complex():
        # Complex views are supported with IR ComplexView
        _warn_complex_not_supported()
        return True

    if t.is_meta:
        return True

    if t.dtype != torch.float8_e8m0fnu:
        return False

    if not node:
        return True

    # allow bitcast, views, memory movement, but not arithmetic
    # TODO: delete once triton adds native support
    target_is_overload = isinstance(node.target, torch._ops.OpOverload)
    allowed = target_is_overload and (
        node.target
        in (
            aten.view.dtype,
            aten.cat.default,
            aten.clone.default,
            aten._scaled_mm.default,
        )
        or is_view(node.target)
    )
    return not allowed
def unsupported_output_tensor(t: torch.Tensor, node=None):
    """Return a truthy value when ``t`` cannot be written (it may still be readable).

    Complex tensors produced by one of the supported complex-view ops are
    accepted.  Otherwise anything rejected by ``unsupported_input_tensor`` is
    rejected here too, as are CPU tensors when ``config.disable_cpp_codegen``
    is set.
    """
    complex_view_ops = (
        aten.view.dtype,
        torch.ops.prims.convert_element_type.default,
    )
    is_supported_complex_view = (
        node is not None and node.target in complex_view_ops and t.is_complex()
    )
    if is_supported_complex_view:
        return False

    if unsupported_input_tensor(t, node):
        return True

    return t.is_cpu and config.disable_cpp_codegen
def fallback_node_due_to_unsupported_type(node: torch.fx.Node, allow_cpu_inputs=True):
    """Decide whether ``node`` must fall back because of an unsupported tensor type.

    Returns True if any FX-node argument carries an unsupported input tensor in
    its ``meta["val"]``, or if ``node`` itself carries an unsupported output
    tensor.  ``view_as_complex``, placeholders and ``lift_fresh_copy`` never
    fall back here.  ``allow_cpu_inputs`` is accepted but not read.
    """
    # Custom fallback lowering
    if node.target is aten.view_as_complex.default or node.op == "placeholder":
        return False

    # We should be able to remove this special case once `disable_cpp_codegen` is killed.
    if node.target is aten.lift_fresh_copy.default:
        return False

    def check_skip_condition(inp_out_node, is_output):
        # Only FX nodes with recorded fake values can be inspected.
        if not isinstance(inp_out_node, torch.fx.Node):
            return False
        if "val" not in inp_out_node.meta:
            return False

        is_unsupported = (
            unsupported_output_tensor if is_output else unsupported_input_tensor
        )
        return any(
            is_unsupported(leaf, node)
            for leaf in pytree.tree_leaves(inp_out_node.meta["val"])
            if isinstance(leaf, torch._subclasses.FakeTensor)
        )

    # only skip codegen if there is a cpu output, not input
    any_bad_input = any(
        check_skip_condition(arg, is_output=False)
        for arg in pytree.arg_tree_leaves(*node.args, **node.kwargs)
    )
    return any_bad_input or check_skip_condition(node, is_output=True)
def make_fallback(op, layout_constraint=None, warn=True, override_decomp=False):
    """Register ``op`` to be lowered through ``fallback_handler``.

    Args:
        op: an ``OpOverloadPacket`` (every overload is registered), an
            ``OpOverload`` or a ``HigherOrderOperator``.
        layout_constraint: optional constraint registered for each overload.
        warn: when true and the ``CI`` environment variable is set, a fallback
            for an op that also has a decomposition raises ``AssertionError``.
        override_decomp: allow the fallback even though a decomposition exists.

    Raises:
        AssertionError: if ``op`` is in ``decompositions`` without
            ``override_decomp``, or on the CI check described above.
        RuntimeError: if ``op`` is none of the supported operator types.
    """
    assert op not in decompositions or override_decomp, (
        f"both a fallback and a decomp for same op: {op}"
    )

    def shadows_existing_decomp():
        if not (warn and bool(os.getenv("CI"))):
            return False
        if not get_decompositions([op]):
            return False
        # if fallback_random, we allow not decomposing random
        if (
            config.fallback_random
            and op in torch._decomp.decompositions_for_rng.extra_random_decomps
        ):
            return False
        return not override_decomp

    if shadows_existing_decomp():
        # Note: 'warn' is holdover from when this was a warning, but for ops that previously
        # set warn=False we do not want a CI error.
        # Ignore the 'suppress errors' configs in CI, as this particular warning happens on startup anyway and is not
        # likely to be triggered preferentially on one CI config over another.
        if torch._dynamo.config.suppress_errors:
            torch._dynamo.config.suppress_errors = False
            log.warning(
                "A make_fallback error occurred in suppress_errors config,"
                " and suppress_errors is being disabled to surface it."
            )
        raise AssertionError(
            f"make_fallback({op}): a decomposition exists, we should switch to it."
            " To fix this error, either add a decomposition to core_aten_decompositions (preferred)"
            " or inductor_decompositions, and delete the corresponding `make_fallback` line."
            " Get help from the inductor team if unsure, don't pick arbitrarily to unblock yourself.",
        )

    def register_fallback(op_overload):
        add_needs_realized_inputs(op_overload)
        if layout_constraint is not None:
            add_layout_constraint(op_overload, layout_constraint)
        lowering_decorator = register_lowering(op_overload, type_promotion_kind=None)
        return lowering_decorator(fallback_handler(op_overload))

    if isinstance(op, torch._ops.OpOverloadPacket):
        for overload_name in op.overloads():
            register_fallback(getattr(op, overload_name))
        return

    if isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
        register_fallback(op)
        return

    raise RuntimeError(f"Unsupported fallback {op} with type {type(op)}")
def philox_rand_offset(shape):
    """
    TorchInductor offset calculation differs from PyTorch eager offset
    calculation for random ops (tl.rand vs torch.rand). In future, we should
    strive for same impl for tl.rand and torch.rand.

    Returns the product of ``shape`` as an int64 ``tensor``.
    """
    element_count = 1
    for extent in shape:
        element_count = element_count * extent
    return tensor(element_count, dtype=torch.int64)
@register_lowering(torch.ops.rngprims.philox_rand, type_promotion_kind=None)
def philox_rand(size, seed, offset, stride, device, dtype):
    """Lower ``rngprims.philox_rand`` to a ``Pointwise`` of ``ops.rand`` values.

    Returns a pair ``(random_values_node, offset_node)``; the second element
    comes from ``philox_rand_offset(size)``.
    """
    # stride arg is optional and will be used in future for distributed random
    # ops. Currently, its unused.
    contiguous_layout = ir.FixedLayout(
        device,
        dtype,
        size,
        ir.FlexibleLayout.contiguous_strides(size),
    )
    random_pos = contiguous_layout.make_indexer()
    load_seed = seed.make_loader()
    load_offset = offset.make_loader()

    def inner_fn(index):
        # Both seed and offset in the philox_rand op are tensors.
        # torch seed and offsets are of type int64, but tl.rand accepts int32
        seed_i32 = ops.to_dtype(load_seed([]), torch.int32)
        offset_i32 = ops.to_dtype(load_offset([]), torch.int32)

        # Get the offset'd position
        position = ops.add(
            ops.index_expr(random_pos(index), torch.int32), offset_i32
        )
        sample = ops.rand(seed_i32, position)
        return ops.to_dtype(sample, dtype)

    random_values_node = Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=list(size),
    )
    return random_values_node, philox_rand_offset(size)
@register_lowering(aten.native_dropout, type_promotion_kind=None)
def native_dropout(x, p, train):
    """Lower ``aten.native_dropout`` via a fallback kernel.

    Only reachable when ``config.fallback_random`` is set; otherwise an
    ``AssertionError`` is raised.
    """
    if not config.fallback_random:
        raise AssertionError("should be handled in replace_random.py")

    fallback_result = ir.FallbackKernel.create(
        aten.native_dropout.default, x, p, train
    )
    return pytree.tree_map(TensorBox.create, fallback_result)
@register_lowering(aten.bernoulli_, type_promotion_kind=None)
def bernoulli_(x, *args):
    """Lower in-place ``aten.bernoulli_`` through ``ir.InplaceBernoulliFallback``.

    The ``float`` overload is used when no extra argument is given or the first
    extra argument is a ``float``; otherwise the ``Tensor`` overload is used.
    ``x`` is realized and returned.
    """
    assert config.fallback_random or x.get_device() == torch.device("cpu"), (
        "this should be handled in decomps unless config.fallback_random or the device is CPU"
    )
    x.realize()

    if not args or isinstance(args[0], float):
        op_overload = aten.bernoulli_.float
    else:
        op_overload = aten.bernoulli_.Tensor
    ir.InplaceBernoulliFallback(op_overload, x, *args)
    return x
@register_lowering(aten.bernoulli.p, type_promotion_kind=None)
def bernoulli_p(x, *args):
    """Lower ``aten.bernoulli.p`` as ``bernoulli_`` applied to a clone of ``x``."""
    assert config.fallback_random or x.get_device() == torch.device("cpu"), (
        "this should be handled in decomps unless config.fallback_random or the device is CPU"
    )
    fresh_copy = clone(x)
    return bernoulli_(fresh_copy, *args)
# This shouldn't be called in general
@register_lowering(aten._foobar)
def _foobar(_):
    """Lowering registered for ``aten._foobar``; always raises ``AssertionError``."""
    raise AssertionError
@functools.lru_cache(1)
def _warn_triton_random(salt):
    """Log the triton-random notice; the size-1 cache suppresses repeats for the same ``salt``."""
    notice = "using triton random, expect difference from eager"
    log.info(notice)
def warn_triton_random():
    """Log the triton-random notice, keyed on the current graph's creation time."""
    # only warn once per graph
    graph_salt = V.graph.creation_time
    _warn_triton_random(graph_salt)
# Fallback handlers for the default and generator overloads of aten.rand and
# aten.randn.  Creating them through fallback_handler with its default
# arguments also records each overload in the fallbacks set.
fallback_rand_default = fallback_handler(aten.rand.default)
fallback_rand_generator = fallback_handler(aten.rand.generator)
fallback_randn_default = fallback_handler(aten.randn.default)
fallback_randn_generator = fallback_handler(aten.randn.generator)
# aten.randn / aten.rand get explicit lowerings; aten.randint is registered as a
# plain fallback.
make_fallback(aten.randint)
@register_lowering(aten.rand)
def rand(*args, **kwargs):
    """Lower ``aten.rand`` through a fallback kernel.

    An explicit generator selects the generator overload.  Otherwise the
    default overload is used, but only when ``config.fallback_random`` is set;
    any other call raises ``AssertionError``.
    """
    has_generator = kwargs.get("generator", None) is not None
    if has_generator:
        return fallback_rand_generator(*args, **kwargs)

    if not config.fallback_random:
        raise AssertionError("should have been handled in replace_random.py")

    # Drop a ``generator=None`` entry before calling the default overload.
    kwargs.pop("generator", None)
    return fallback_rand_default(*args, **kwargs)
@register_lowering(aten.randn)
def randn(*args, **kwargs):
    """Lower ``aten.randn`` through a fallback kernel.

    An explicit generator selects the generator overload.  Otherwise the
    default overload is used, but only when ``config.fallback_random`` is set;
    any other call raises ``AssertionError``.
    """
    has_generator = kwargs.get("generator", None) is not None
    if has_generator:
        return fallback_randn_generator(*args, **kwargs)

    if not config.fallback_random:
        raise AssertionError("should have been handled in replace_random.py")

    # Drop a ``generator=None`` entry before calling the default overload.
    kwargs.pop("generator", None)
    return fallback_randn_default(*args, **kwargs)
@register_lowering(inductor_prims.force_stride_order, type_promotion_kind=None)
def inductor_force_stride_order(input_tensor, stride):
    """Require ``input_tensor`` to follow the stride order derived from ``stride``."""
    required_order = ir.get_stride_order(stride)
    return ir.ExternKernel.require_stride_order(input_tensor, required_order)
@register_lowering(inductor_prims.seed, type_promotion_kind=None)
def inductor_seed(device: torch.device):
    """Lowering for ``inductor_prims.seed``; always raises ``AssertionError``."""
    raise AssertionError("should be handled in fuse_seed_creation_pass()")
@register_lowering(inductor_prims.seeds, type_promotion_kind=None)
def inductor_seeds(count, device):
    """Create an ``ir.RandomSeeds`` buffer holding ``count`` seeds on ``device``."""
    warn_triton_random()
    seeds_buffer = ir.RandomSeeds(count, decode_device(device))
    return TensorBox.create(seeds_buffer)
@register_lowering(inductor_prims.lookup_seed, type_promotion_kind=None)
def inductor_lookup_seed(seeds, index):
    """Build a scalar (``ranges=[]``) ``Pointwise`` that loads seed ``index`` from ``seeds``."""

    def load_selected_seed(_):
        return ops.load_seed(seeds.get_name(), index)

    return Pointwise.create(
        device=seeds.get_device(),
        dtype=seeds.get_dtype(),
        inner_fn=load_selected_seed,
        ranges=[],
    )
@register_lowering(inductor_prims.random, type_promotion_kind=None)
def inductor_random(size: list[int], seed: TensorBox, mode: str, *, offset: int = 0):
    """Lower ``inductor_prims.random`` to a realized float32 ``Pointwise``.

    ``mode`` selects ``ops.rand`` or ``ops.randn``.  Each element is sampled
    from the loaded seed and its position in a contiguous layout of ``size``
    starting at ``offset``.
    """
    assert not config.fallback_random
    assert mode in ("rand", "randn")

    size = list(size)
    dtype = torch.float32
    device = seed.get_device_or_error()
    position_layout = ir.FixedLayout(
        device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
    )
    random_pos = position_layout.make_indexer()
    load_seed = seed.make_loader()

    def inner_fn(index):
        sampler = getattr(ops, mode)
        return sampler(
            load_seed([]),
            ops.index_expr(random_pos(index), torch.int32),
        )

    random_node = Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=list(size),
    )
    random_node.realize()
    return random_node
@register_lowering(inductor_prims.randint, type_promotion_kind=None)
def inductor_randint(
    low: int, high: int, size: list[int], seed: TensorBox, *, offset: int = 0
):
    """Lower ``inductor_prims.randint`` to an int64 ``Pointwise`` of ``ops.randint64``.

    Each element is sampled from the loaded seed, its position in a contiguous
    layout of ``size`` starting at ``offset``, and the ``low`` / ``high`` bounds.
    """
    assert not config.fallback_random

    size = list(size)
    dtype = torch.int64
    device = seed.get_device_or_error()
    position_layout = ir.FixedLayout(
        device, dtype, size, ir.FlexibleLayout.contiguous_strides(size), offset=offset
    )
    random_pos = position_layout.make_indexer()
    load_seed = seed.make_loader()

    def inner_fn(index):
        seed_value = load_seed([])
        position = ops.index_expr(random_pos(index), torch.int32)
        low_bound = ops.index_expr(low, torch.int64)
        high_bound = ops.index_expr(high, torch.int64)
        return ops.randint64(seed_value, position, low_bound, high_bound)

    return Pointwise.create(
        device=device,
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=list(size),
    )
def _boundaries_helper(tb: TensorBox) -> tuple[str, sympy.Expr, sympy.Expr, sympy.Expr]:
    """Describe ``tb`` as the tuple handed to ``ops.bucketize``.

    Returns:
        ``(name, size[-1], size[0] * stride[0], stride[-1])`` taken from
        ``tb.get_name()``, ``tb.get_size()`` and ``tb.get_stride()``.
    """
    name = tb.get_name()
    sizes = tb.get_size()
    strides = tb.get_stride()
    last_dim_size = sizes[-1]
    outer_extent = sizes[0] * strides[0]
    last_dim_stride = strides[-1]
    return (name, last_dim_size, outer_extent, last_dim_stride)
def _sorter_helper(tb: TensorBox) -> tuple[str, sympy.Expr]:
    """Return ``(name, stride of the last dimension)`` for ``tb``."""
    name = tb.get_name()
    last_dim_stride = tb.get_stride()[-1]
    return name, last_dim_stride
@register_lowering(aten.searchsorted.Tensor, type_promotion_kind=None)
def searchsorted(
    sorted_sequence: TensorBox,
    self: TensorBox,
    *,
    out_int32: bool = False,
    right: bool = False,
    side: Optional[str] = None,
    sorter: Optional[TensorBox] = None,
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Lower ``aten.searchsorted.Tensor`` onto ``ops.bucketize``.

    Falls back to the ATen kernel when ``sorted_sequence``, ``self`` or a
    provided ``sorter`` lacks ``BackendFeature.BUCKETIZE``. Otherwise builds a
    realized ``Pointwise`` over ``self.shape`` with dtype int32 when
    ``out_int32`` is set and int64 otherwise.
    """

    def supports_bucketize(tb):
        return V.graph.has_feature(tb, BackendFeature.BUCKETIZE)

    # Checked in this order, stopping at the first tensor without support.
    operands = [sorted_sequence, self]
    if sorter is not None:
        operands.append(sorter)
    if not all(supports_bucketize(tb) for tb in operands):
        fallback = fallback_handler(aten.searchsorted.Tensor, add_to_fallback_set=False)
        return fallback(
            sorted_sequence,
            self,
            out_int32=out_int32,
            right=right,
            side=side,
            sorter=sorter,
        )

    # If side is present, override the value of right if needed. This assumes that
    # validation of the two options being non-contradictory is already done by the
    # searchsorted meta-function.
    if side == "right":
        right = True

    index_dtype = torch.int32 if out_int32 else torch.int64
    values_loader = self.make_loader()

    # The entire sorted_sequence tensor needs to be used by ops.bucketize, so we need to
    # realize it into global memory; or in other words, we can't guarantee that
    # sorted_sequence.get_name() (used below) will exist unless we call
    # sorted_sequence.realize().
    sorted_sequence.realize()
    if sorter is not None:
        sorter.realize()

    is_1d = len(sorted_sequence.get_size()) == 1

    def row_start(tb: TensorBox, idx):
        # Index of the beginning of the sorted sequence within a flattened
        # version of the array: sum of stride * index over all but the last dim.
        leading_strides = tb.get_stride()[:-1]
        terms = (s * i for s, i in zip(leading_strides, idx[:-1]))
        return ops.index_expr(functools.reduce(operator.add, terms), index_dtype)

    def inner_fn(idx):
        val = values_loader(idx)
        boundaries = _boundaries_helper(sorted_sequence)
        # A 1-D sorted sequence always starts at offset 0.
        boundary_indices = 0 if is_1d else row_start(sorted_sequence, idx)
        if sorter is None:
            sorter_info = None
            sorter_indices = None
        else:
            sorter_info = _sorter_helper(sorter)
            sorter_indices = 0 if is_1d else row_start(sorter, idx)
        return ops.bucketize(
            val,
            boundaries,
            boundary_indices,
            index_dtype,
            right,
            sorter=sorter_info,
            sorter_indices=sorter_indices,
        )

    out = Pointwise.create(
        device=self.get_device(),
        dtype=index_dtype,
        inner_fn=inner_fn,
        ranges=self.shape,
    )
    # see [NOTE: inductor bucketize realize]
    out.realize()
    return out
@register_lowering(
    aten.bucketize, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.NO_OPMATH
)
def bucketize(
    input: TensorBox,
    boundaries: TensorBox,
    *,
    out_int32: bool = False,
    right: bool = False,
):
    """Lower ``aten.bucketize`` onto ``ops.bucketize``.

    ``boundaries`` must be one-dimensional (asserted). Falls back to
    ``aten.bucketize.Tensor`` when either tensor lacks
    ``BackendFeature.BUCKETIZE``. The result is a realized ``Pointwise`` over
    the size of ``input`` with dtype int32 when ``out_int32`` is set and int64
    otherwise.
    """
    assert len(boundaries.get_size()) == 1

    if not V.graph.has_feature(
        input, BackendFeature.BUCKETIZE
    ) or not V.graph.has_feature(boundaries, BackendFeature.BUCKETIZE):
        fallback = fallback_handler(aten.bucketize.Tensor, add_to_fallback_set=False)
        return fallback(input, boundaries, out_int32=out_int32, right=right)

    # The entire boundaries tensor needs to be used by ops.bucketize, so we
    # need to realize it into global memory; or in other words, we can't
    # guarantee that boundaries.get_name() (used below) will exist unless
    # we call boundaries.realize().
    boundaries.realize()
    device = input.get_device()
    input_loader = input.make_loader()
    index_dtype = torch.int32 if out_int32 else torch.int64

    def inner_fn(index):
        val = input_loader(index)
        return ops.bucketize(
            val,
            _boundaries_helper(boundaries),
            0,
            index_dtype,
            right,
        )

    out = Pointwise.create(
        device=device,
        dtype=index_dtype,
        inner_fn=inner_fn,
        ranges=input.get_size(),
    )

    # [NOTE: inductor bucketize realize]
    # bucketize_binary_search is relatively expensive, so we don't want to re-compute
    # it unnecessarily. If we run bucketize() and then broadcast the result, we don't
    # want this to be fused into a large number of duplicate bucketize() computations
    # for each of the elements in the result.
    #
    # If no broadcasting occurs, fusions can still occur in scheduler.py
    out.realize()
    return out
def require_dense(_, *args, **kwargs):
    """Apply ``ir.ExternKernel.require_stride1`` to every ``ir.IRNode`` argument.

    The first positional parameter is ignored. Returns the mapped
    ``(args, kwargs)`` pair.
    """
    new_args, new_kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_stride1, (args, kwargs)
    )
    return new_args, new_kwargs
def require_contiguous(_, *args, **kwargs):
    """Apply ``ir.ExternKernel.require_contiguous`` to every ``ir.IRNode`` argument.

    The first positional parameter is ignored. Returns the mapped
    ``(args, kwargs)`` pair.
    """
    new_args, new_kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_contiguous, (args, kwargs)
    )
    return new_args, new_kwargs
def require_contiguous_strides(_, *args, **kwargs):
    """Apply ``ir.ExternKernel.require_contiguous_strides`` to every ``ir.IRNode`` argument.

    The first positional parameter is ignored. Returns the mapped
    ``(args, kwargs)`` pair.
    """
    # TODO: combine this with require_contiguous after
    # https://github.com/pytorch/pytorch/pull/148235 lands.
    new_args, new_kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_contiguous_strides, (args, kwargs)
    )
    return new_args, new_kwargs
def require_channels_last(_, *args, **kwargs):
    """Apply ``ir.ExternKernel.require_channels_last`` to every ``ir.IRNode`` argument.

    The first positional parameter is ignored. Returns the mapped
    ``(args, kwargs)`` pair.
    """
    new_args, new_kwargs = pytree.tree_map_only(
        ir.IRNode, ir.ExternKernel.require_channels_last, (args, kwargs)
    )
    return new_args, new_kwargs
def constrain_to_fake_tensor(arg, fake_arg):
    """Constrain ``arg`` to the strides reported by ``fake_arg``.

    * ``ir.IRNode``: passed through ``ir.ExternKernel.require_exact_strides``
      with the strides of ``fake_arg`` (``SymInt`` strides are replaced by
      their underlying expression).
    * ``dict`` / ``tuple`` / ``list``: handled recursively, element by element,
      pairing each entry with the matching entry of ``fake_arg``.
    * anything else: returned unchanged.
    """
    if isinstance(arg, ir.IRNode):
        expected_strides = []
        for s in fake_arg.stride():
            expected_strides.append(s.node.expr if isinstance(s, torch.SymInt) else s)
        return ir.ExternKernel.require_exact_strides(arg, expected_strides)

    if isinstance(arg, dict):
        return {
            key: constrain_to_fake_tensor(value, fake_arg[key])
            for key, value in arg.items()
        }

    if isinstance(arg, (tuple, list)):
        container = type(arg)
        return container(
            constrain_to_fake_tensor(item, fake_item)
            for item, fake_item in zip(arg, fake_arg)
        )

    return arg
def constrain_to_fake_tensors(args, kwargs, fake_args, fake_kwargs):
    """Apply ``constrain_to_fake_tensor`` to all positional and keyword arguments.

    Positional arguments are paired with ``fake_args`` by position; keyword
    arguments are paired with ``fake_kwargs`` by key. Returns the new
    ``(args, kwargs)`` with ``args`` as a tuple.
    """
    constrained_args = tuple(
        constrain_to_fake_tensor(real, fake) for real, fake in zip(args, fake_args)
    )
    constrained_kwargs = {}
    for name, value in kwargs.items():
        constrained_kwargs[name] = constrain_to_fake_tensor(value, fake_kwargs[name])
    return constrained_args, constrained_kwargs
def constrain_to_fx_strides(fx_node, *args, **kwargs):
    """Constrain IR arguments to the stride order recorded on ``fx_node``.

    For every ``ir.IRNode`` argument, the stride order is computed from the
    strides of the matching FX argument's ``meta["val"]`` and enforced with
    ``ir.ExternKernel.require_stride_order``. Dict arguments are handled
    recursively per key; other values pass through unchanged. Returns the new
    ``(args, kwargs)``.
    """

    def apply_constraint(arg, fx_arg):
        if isinstance(arg, ir.IRNode):
            fx_strides = fx_arg.meta["val"].stride()
            order = ir.get_stride_order(fx_strides, V.graph.sizevars.shape_env)
            return ir.ExternKernel.require_stride_order(arg, order)
        if isinstance(arg, dict):
            return {
                key: apply_constraint(value, fx_arg[key]) for key, value in arg.items()
            }
        return arg

    new_args = tuple(
        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
    )
    new_kwargs = {}
    for name, value in kwargs.items():
        new_kwargs[name] = apply_constraint(value, fx_node.kwargs[name])
    return new_args, new_kwargs
def sdpa_constraint(fx_node, *args, **kwargs):
    """Layout constraint applied to the arguments of the SDPA fallbacks.

    Each positional argument is constrained together with its index and the
    matching entry of ``fx_node.args``; keyword arguments use index ``-1`` and
    the matching entry of ``fx_node.kwargs``. Non-``ir.IRNode`` values pass
    through unchanged.

    Returns:
        The constrained ``(args, kwargs)``; ``args`` is a tuple.
    """
    # sdpa requires a dense last dimension

    def apply_constraint(idx, arg, fx_arg):
        if not isinstance(arg, ir.IRNode):
            return arg

        meta_val = fx_arg.meta["val"]
        meta_stride_expr = [
            s.node.expr if isinstance(s, torch.SymInt) else s for s in meta_val.stride()
        ]

        stride_order = ir.get_stride_order(meta_val.stride())

        if stride_order and stride_order[-1] != 0:
            # contiguous stride order
            stride_order = list(reversed(range(len(arg.get_size()))))

        if (
            fx_node.target
            == aten._scaled_dot_product_efficient_attention_backward.default
            and idx in (0, 5)
        ):
            assert len(stride_order) == 4
            # The 0 and 5th arguments for aten._scaled_dot_product_efficient_attention_backward.default
            # are for out and gradient_out. They have to be in
            # (3, 1, 2, 0) stride order. Otherwise the kernel will crash.
            # Check https://github.com/pytorch/pytorch/issues/138772
            stride_order = (3, 1, 2, 0)

        if not meta_val.is_cuda:
            return ir.ExternKernel.require_stride_order(arg, stride_order)

        # This is the minimum alignment required by SDPA kernels for attention_bias.
        # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
        ALIGNMENT = 8

        # effn_attn_fwd requires a dense last dim, not just alignment
        effn_attn_fwd_bias = (
            fx_node.target
            == torch.ops.aten._scaled_dot_product_efficient_attention.default
            and idx == 3
        )

        assert isinstance(arg, TensorBox)
        if len(arg.get_size()) not in (3, 4):
            return arg

        # Already-aligned tensors only need their insignificant strides matched
        # to the FX metadata. Everything below this early return therefore runs
        # only for tensors that are not reported as aligned.
        if ir.is_aligned_realized_tensor_hint(arg, ALIGNMENT):
            return ir.try_match_insignificant_strides(
                ir.ExternKernel.realize_input(arg), meta_stride_expr
            )

        if effn_attn_fwd_bias:
            out_size = list(arg.get_size())

            expanded_dims = []
            # We require a dense last dimension, but the other strides
            # can be expanded, which results in a smaller tensor
            maybe_stride = arg.maybe_get_stride()
            for i in range(len(arg.get_size()) - 1):
                if V.graph.sizevars.statically_known_equals(meta_stride_expr[i], 0) or (
                    maybe_stride is not None
                    and V.graph.sizevars.statically_known_equals(maybe_stride[i], 0)
                ):
                    expanded_dims.append(i)

            # Now, pad strides to alignment
            out_strides = [-1] * len(out_size)
            out_strides[-1] = 1
            stride = 1
            for i in range(len(out_size) - 2, -1, -1):
                if out_strides[i + 1] != 0:
                    stride = stride * out_size[i + 1]

                # the expanded dims still need to be aligned, if they are,
                # we can make them expanded by setting the stride equal to 0
                if i in expanded_dims:
                    if V.graph.sizevars.statically_known_equals(
                        out_strides[i + 1] % ALIGNMENT, 0
                    ):
                        out_strides[i] = 0
                        continue

                if not V.graph.sizevars.statically_known_equals(stride % ALIGNMENT, 0):
                    stride = ceildiv(stride, ALIGNMENT) * ALIGNMENT

                out_strides[i] = stride

            return ir.ExternKernel.require_exact_strides(arg, out_strides)

        def is_aligned(x):
            return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0

        # A view whose own last dimension is unaligned but whose unwrapped view
        # is aligned is realized and matched against the FX strides instead of
        # being forced into ``stride_order``.
        if (
            isinstance(arg.data, ir.BaseView)
            and not is_aligned(arg)
            and is_aligned(arg.unwrap_view())
        ):
            return ir.try_match_insignificant_strides(
                ir.ExternKernel.realize_input(arg), meta_stride_expr
            )

        return ir.ExternKernel.require_stride_order(arg, stride_order)

    args = tuple(
        apply_constraint(idx, arg, fx_arg)
        for idx, (arg, fx_arg) in enumerate(zip(args, fx_node.args))
    )
    kwargs = {k: apply_constraint(-1, v, fx_node.kwargs[k]) for k, v in kwargs.items()}
    return args, kwargs
# WIP
make_fallback(aten._adaptive_avg_pool3d)  # @isuruf
make_fallback(aten.adaptive_max_pool3d)  # @isuruf
make_fallback(aten._scaled_dot_product_attention_math_for_mps)  # @malfet

# 1) Easy
make_fallback(aten.uniform, warn=False)
make_fallback(aten.exponential.default, warn=False)  # (fails accuracy on test_torch.py)
make_fallback(aten._pdist_forward)  # Has decomp. Needs benchmarks
make_fallback(aten.soft_margin_loss_backward, warn=False)  # py_impl?
make_fallback(aten._fused_rms_norm, warn=False)  # (MPS-only and faster than decomp)
if torch.xpu.is_available():
    make_fallback(
        aten.embedding_dense_backward, warn=False
    )  # (XPU-only and faster than decomp)


# 1.5) Easy or Impossible
make_fallback(aten._cdist_forward)  # p=2 should be feasible
make_fallback(aten._cdist_backward)

# 2) Medium
make_fallback(aten._trilinear)


# 3) Difficult
# Scans
# See the discussion at
# https://dev-discuss.pytorch.org/t/pytorch-sparse-gnn-compiler-rfc/1644/19
make_fallback(aten.segment_reduce.default)
make_fallback(aten._segment_reduce_backward.default)

# Histogram (need to implement Histogram IR)
make_fallback(aten.histc)
make_fallback(aten.histogram.bin_ct)
make_fallback(aten._histogramdd_bin_edges.default)
make_fallback(aten._histogramdd_from_bin_cts.default)

# Need templated kernel
make_fallback(aten.addbmm)
make_fallback(aten._addmm_activation, warn=False)

make_fallback(aten._grouped_mm, require_dense)

# Need templated kernel. Probably impossible to write efficiently
make_fallback(aten.convolution_backward, constrain_to_fx_strides)
make_fallback(aten._cudnn_rnn, require_dense)
make_fallback(aten._cudnn_rnn_backward, require_contiguous)

# Haven't checked but sound difficult / impossible
make_fallback(aten._embedding_bag, require_contiguous)
make_fallback(aten._embedding_bag_forward_only, require_contiguous)
make_fallback(aten._embedding_bag_backward)
make_fallback(aten._embedding_bag_per_sample_weights_backward)
make_fallback(aten._fused_moving_avg_obs_fq_helper)
make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)


# 4) Backwards (try py_impl'ing them) when fwd is written as a decomp
make_fallback(aten.max_pool3d_with_indices_backward)
make_fallback(aten._adaptive_avg_pool2d_backward, require_dense)
make_fallback(aten._adaptive_avg_pool3d_backward)
make_fallback(aten.adaptive_max_pool2d_backward)
make_fallback(aten.adaptive_max_pool3d_backward)
make_fallback(aten.fractional_max_pool2d_backward)
make_fallback(aten.fractional_max_pool3d_backward)
make_fallback(aten.replication_pad1d_backward)
make_fallback(aten.replication_pad2d_backward)
make_fallback(aten.upsample_linear1d_backward)
make_fallback(aten.upsample_bicubic2d_backward, require_contiguous)
make_fallback(aten.upsample_trilinear3d_backward)
make_fallback(aten.grid_sampler_2d_backward, require_dense)
make_fallback(aten._pdist_backward)


# 5) Impossible (missing triton/CPU features)

# Sorting / Sorting-like
make_fallback(aten.sort)
make_fallback(aten.sort.stable)
make_fallback(aten.kthvalue)
make_fallback(aten.topk)
make_fallback(aten.mode)
make_fallback(aten.median)
make_fallback(aten.nanmedian)
make_fallback(aten.randperm)
# see: https://github.com/pytorch/pytorch/pull/121354
make_fallback(aten.resize_)
make_fallback(aten.resize_as_)

# Linalg
make_fallback(aten._linalg_det)
make_fallback(aten.linalg_householder_product)
make_fallback(aten.linalg_inv_ex)
make_fallback(aten.linalg_ldl_factor_ex)
make_fallback(aten.linalg_ldl_solve)
make_fallback(aten.linalg_lu)
make_fallback(aten.linalg_lu_factor_ex)
make_fallback(aten.linalg_lu_solve)
make_fallback(aten.linalg_matrix_exp)
make_fallback(aten.linalg_qr)
make_fallback(aten._linalg_slogdet)
make_fallback(aten._linalg_solve_ex)
make_fallback(aten.linalg_solve_triangular)
make_fallback(aten._linalg_svd)
make_fallback(aten.lu_unpack)
make_fallback(aten.ormqr)
make_fallback(aten._linalg_check_errors)
make_fallback(aten.linalg_pinv.atol_rtol_tensor)
make_fallback(aten._linalg_eigh)
make_fallback(aten.triangular_solve)
make_fallback(aten.linalg_cholesky_ex)
make_fallback(aten.cholesky_inverse)
make_fallback(aten.cholesky_solve)
make_fallback(aten.geqrf)
make_fallback(aten._fft_r2c)  # needs complex as well

# Data dependent (are these necessary?)
make_fallback(aten.nonzero.default)

# Misc
make_fallback(aten.gcd.default, warn=False)
make_fallback(aten._thnn_fused_lstm_cell, require_dense)
make_fallback(torch._prims.rng_prims.run_and_save_rng_state)
make_fallback(torch._prims.rng_prims.run_with_rng_state)
make_fallback(torch._prims.rng_prims.graphsafe_run_with_rng_state)


# Implemented / Half implemented
# Scans. Implemented for CUDA, missing CPU
make_fallback(aten.masked_scatter)
make_fallback(aten.masked_scatter_backward)

# Complex number support
make_fallback(aten.view_as_complex, require_contiguous)
make_fallback(aten.angle)  # needs complex

# Needs efficientzerotensor
make_fallback(aten._efficientzerotensor)

# Needs Sparse
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
make_fallback(aten.to_sparse)
make_fallback(aten._to_sparse)

# Needs dimname support
make_fallback(aten.zeros.names)

# 6) Pattern-matched
make_fallback(
    aten._scaled_dot_product_efficient_attention.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_efficient_attention_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_cudnn_attention.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_cudnn_attention_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention_for_cpu.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_flash_attention_for_cpu_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_fused_attention_overrideable.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(
    aten._scaled_dot_product_fused_attention_overrideable_backward.default,
    sdpa_constraint,
    warn=False,
)
make_fallback(aten._flash_attention_forward.default, sdpa_constraint)
make_fallback(aten._flash_attention_backward.default, sdpa_constraint)
make_fallback(aten._efficient_attention_forward.default, sdpa_constraint)
make_fallback(aten._efficient_attention_backward.default, sdpa_constraint)

# index_reduce requires fallback when use_scatter_fallback(...) returns True
make_fallback(aten.index_reduce)
make_fallback(aten.repeat_interleave.Tensor, override_decomp=True)
# Register with type_promotion_kind None.
# For example, fp16.copy_(fp32) should **not** promote the first input's dtype.
@register_lowering(aten.copy, type_promotion_kind=None)
def copy(self, src, non_blocking=False):
    """Lower ``aten.copy``: return a clone of ``src`` adapted to ``self``.

    ``src`` is moved to the device of ``self``, cast to its dtype and expanded
    to its size, each step only when the two differ, and the result is cloned.
    ``non_blocking`` is accepted but not used.
    """
    converted = src
    if self.get_device() != src.get_device():
        converted = to_device(converted, self.get_device())
    if self.get_dtype() != src.get_dtype():
        converted = to_dtype(converted, self.get_dtype())
    if self.get_size() != src.get_size():
        converted = expand(converted, self.get_size())
    return clone(converted)
@register_lowering(aten.clone)
def clone(x, *, memory_format=None):
    """Lower ``aten.clone`` to a ``Pointwise`` that re-reads ``x`` via its loader.

    ``memory_format`` is accepted but currently ignored.
    """
    # TODO(jansel): memory format
    device = x.get_device()
    dtype = x.get_dtype()
    loader = x.make_loader()
    sizes = list(x.get_size())
    return Pointwise.create(device=device, dtype=dtype, inner_fn=loader, ranges=sizes)
def clone_preserve_reinterpret_view(x):
    """Clone ``x`` and rebuild any ``ir.ReinterpretView`` chain around the clone.

    If ``x`` is a ``TensorBox`` whose data is an ``ir.ReinterpretView``, the
    nested views are peeled off (remembering each layout, outermost first),
    the innermost data is cloned, and the views are re-applied innermost first
    with the recorded layouts. Otherwise this is just ``clone(x)``.
    """
    layouts = []
    if isinstance(x, TensorBox) and isinstance(x.data, ir.ReinterpretView):
        node = x.data  # unwrap TensorBox
        while isinstance(node, ir.ReinterpretView):
            layouts.append(node.get_layout())
            node = node.data
        x = TensorBox(node)

    cloned = clone(x)
    if not layouts:
        return cloned

    node = cloned.data  # unwrap TensorBox
    for layout in reversed(layouts):
        node = ir.ReinterpretView(data=node, layout=layout)
    return TensorBox(node)
# Reuse the clone lowering for aten.lift_fresh_copy when that op exists.
_lift_fresh_copy = getattr(aten, "lift_fresh_copy", None)
if _lift_fresh_copy is not None:
    register_lowering(_lift_fresh_copy)(clone)
@register_lowering(prims.iota)
def iota(
    length,
    *,
    start,
    step,
    dtype,
    device,
    requires_grad,
):
    """Lower ``prims.iota`` to a 1-D ``Pointwise`` of ``length`` elements.

    Element ``i`` is the index expression ``step * i + start`` in ``dtype``.
    ``requires_grad`` is accepted but not used.
    """

    def fn(index):
        value = step * index[0] + start
        return ops.index_expr(value, dtype=dtype)

    target_device = decode_device(device)
    return Pointwise.create(
        device=target_device, dtype=dtype, inner_fn=fn, ranges=[length]
    )
@register_lowering(aten.select_scatter, type_promotion_kind=None)
def select_scatter(x, src, dim: int, index: int):
    """Lower ``aten.select_scatter``.

    Produces a ``Pointwise`` shaped like ``x`` that reads from ``src``
    (unsqueezed at ``dim`` and expanded to the size of ``x``) where the
    coordinate along ``dim`` equals ``index``, and from ``x`` elsewhere. A
    negative ``index`` is wrapped by the size of ``dim``; the wrapped value is
    range-checked through ``V.graph.sizevars``.
    """
    assert x.get_dtype() == src.get_dtype()
    x_loader = x.make_loader()
    dim = _validate_dim(x, dim, 0)
    sizevars = V.graph.sizevars
    if sizevars.evaluate_expr(sympy.Lt(index, 0)):
        index = index + x.get_size()[dim]
    sizevars.check_leq(0, index)  # type: ignore[arg-type]
    sizevars.check_lt(index, x.get_size()[dim])  # type: ignore[arg-type]
    src = expand(unsqueeze(src, dim), x.get_size())
    src_loader = src.make_loader()

    def inner_fn(idx):
        position = ops.index_expr(idx[dim], torch.int32)
        selected = ops.index_expr(index, torch.int32)
        on_selected_slice = ops.eq(position, selected)
        from_src = src_loader(idx)
        from_x = x_loader(idx)
        return ops.where(on_selected_slice, from_src, from_x)

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(x.get_size()),
    )
@register_lowering(aten.slice_scatter, type_promotion_kind=None)
def slice_scatter(x, src, dim=0, start=None, end=None, step=1):
    """Lower ``aten.slice_scatter``.

    Produces a ``Pointwise`` shaped like ``x``. Positions along ``dim`` that
    fall inside the normalized ``[start, end)`` range and on the ``step`` grid
    read from ``src`` (cast to the dtype of ``x`` and expanded to the slice
    shape); every other position reads from ``x``.
    """
    src = to_dtype(src, x.get_dtype())
    x_loader = x.make_loader()
    dim = _validate_dim(x, dim, 0)
    dim_size = x.get_size()[dim]

    start, end = ir.SliceView.normalize_start_end(x, dim, start, end)

    slice_shape = list(x.get_size())
    slice_shape[dim] = FloorDiv(end - start + (step - 1), step)
    src = expand(src, slice_shape)
    src_loader = src.make_loader()

    # These comparisons only involve values fixed above, so evaluate them once.
    covers_whole_dim = start == 0 and end == dim_size and step == 1

    def inner_fn(idx):
        if covers_whole_dim:
            # selecting every element is the same as just src.clone()
            return src_loader(idx)

        idx_dim = ops.index_expr(idx[dim], torch.int64)
        src_idx = list(idx)
        src_idx[dim] = FloorDiv(idx[dim] - start, step)

        # One condition per constraint that is actually restrictive.
        conditions = []
        if start != 0:
            lower = ops.index_expr(sympy.expand(start), torch.int64)
            conditions.append(ops.ge(idx_dim, lower))
        if end != dim_size:
            upper = ops.index_expr(sympy.expand(end), torch.int64)
            conditions.append(ops.lt(idx_dim, upper))
        if step != 1:
            remainder = ops.index_expr(
                ModularIndexing(idx[dim] - start, 1, step), torch.int64
            )
            conditions.append(ops.eq(remainder, ops.constant(0, torch.int64)))
        assert conditions
        mask = functools.reduce(ops.and_, conditions)

        src_val = ops.masked(
            mask,
            lambda: src_loader(src_idx),
            0 if is_integer_type(x) else 0.0,
        )
        return ops.where(mask, src_val, x_loader(idx))

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=inner_fn,
        ranges=list(x.get_size()),
    )
def _unwrap(x):
    """Return the first leaf of a nested list/tuple.

    Descends into element 0 for as long as the current value is a non-empty
    ``list`` or ``tuple``. Any other value, including an empty list or tuple,
    is returned as is.
    """
    while isinstance(x, (list, tuple)) and len(x) > 0:
        x = x[0]
    return x
@register_lowering([torch.tensor, aten.scalar_tensor])
def tensor(data, *, dtype=None, device=None, layout=None, pin_memory=False):
    """Lower ``torch.tensor`` / ``aten.scalar_tensor``.

    * sympy expression or Python scalar: a zero-dimensional ``Pointwise``.
    * empty sequence, or a sequence of at most 8 entries whose first entry is
      a Python number: inlined as a 1-D ``Pointwise`` that selects the constant
      by binary search on the index.
    * anything else: materialized with ``torch.tensor`` and registered through
      ``V.graph.add_tensor_constant``.

    When ``dtype`` is not given it defaults to int64 if the first leaf of
    ``data`` is an ``int``, otherwise to the default dtype.
    """
    assert_nyi(layout in (None, torch.strided), f"layout={layout}")
    assert_nyi(not pin_memory, "pin_memory")
    if not dtype:
        if isinstance(_unwrap(data), int):
            dtype = torch.int64
        else:
            dtype = torch.get_default_dtype()

    ranges: list[sympy.Expr] = []

    if isinstance(data, sympy.Basic):

        def inner_fn(index):
            return ops.index_expr(data, dtype)

    elif isinstance(data, (float, int)):

        def inner_fn(index):
            return ops.constant(data, dtype)

    elif len(data) == 0 or (isinstance(data[0], (float, int)) and len(data) <= 8):
        # inline small tensors
        ranges.append(sympy.Integer(len(data)))

        def inner_fn(index):
            if len(data) == 0:
                return ops.constant(0, dtype)

            def select(lo, hi):
                # Pick data[index[0]] out of data[lo:hi] with a tree of ops.where.
                assert lo < hi
                if hi - lo == 1:
                    return ops.constant(data[lo], dtype)
                mid = (hi - lo) // 2 + lo
                in_lower_half = ops.lt(
                    ops.index_expr(index[0], torch.int64),
                    ops.constant(mid, torch.int64),
                )
                lower_value = select(lo, mid)
                upper_value = select(mid, hi)
                return ops.where(in_lower_half, lower_value, upper_value)

            return select(0, len(data))

    else:
        materialized = torch.tensor(data, dtype=dtype, device=device)
        return V.graph.add_tensor_constant(materialized)

    return Pointwise.create(
        device=decode_device(device),
        dtype=dtype,
        inner_fn=inner_fn,
        ranges=ranges,
    )
@register_lowering(torch.as_tensor)
def as_tensor(data, dtype=None, device=None):
    """Lower ``torch.as_tensor``.

    A ``TensorBox`` input is converted to ``dtype`` and then moved to
    ``device``, each only when given, and returned. Any other input is
    forwarded to ``tensor``.
    """
    if not isinstance(data, TensorBox):
        return tensor(data, dtype=dtype, device=device)

    converted = data
    if dtype is not None:
        converted = to_dtype(converted, dtype)
    if device is not None:
        converted = to_device(converted, device)
    return converted
@register_lowering(torch.LongTensor)
def long_tensor(data):
    """Lower ``torch.LongTensor(data)`` as ``tensor(data, dtype=torch.int64)``."""
    int64 = torch.int64
    return tensor(data, dtype=int64)
@register_lowering(aten._local_scalar_dense)
def _local_scalar_dense(data):
    """Lower ``aten._local_scalar_dense`` (``.item()``) to a sympy expression.

    Registers an ``ir.DynamicScalar`` buffer/operation for the single unbacked
    binding of the current node, then returns the expression of the node's
    ``meta["val"]`` (or ``sympy.sympify`` of it when it is not a Sym type).
    """
    # This is interesting!  Most lowerings return tensors, so you can just
    # return the buffer you allocated and it will get used (or not used, if
    # it's dead.)  But _local_scalar_dense (aka item) returns an int,
    # not a Tensor, so you would have a type mismatch if you return a buffer;
    # we are obligated to return a sympy expression instead.  However,
    # we need to actually codegen the .item() call somehow.  We do this
    # by registering a faux buffer for the DynamicScalar IR node, which is
    # solely responsible for generating this .item().  The buffer is
    # not used for anything (notice we discard it); at codegen time,
    # the "buffer" just gets assigned None.
    current_node = V.graph.current_node
    unbacked_bindings = resolve_unbacked_bindings(
        V.graph.sizevars.shape_env, current_node.meta["unbacked_bindings"]
    )
    assert unbacked_bindings is not None
    assert len(unbacked_bindings) == 1, unbacked_bindings
    # NB: Have to be very careful here.  V.graph.current_node.meta["val"]
    # seemingly also contains a symbol which you want to do binding for,
    # but it actually isn't.  In particular, if we have later performed
    # a deferred runtime assert saying that u0 == s0, you will actually
    # see s0 from expr!  This is bad because we need to actually generate
    # the assert that says u0 == s0, so we need to know where to get u0
    # from (this call).  In particular, we must use unbacked_bindings, which
    # is guaranteed to have the original, unreplaced symbol in question.
    #
    # NB2: Another thing we have to be very careful about are symbol bindings
    # that require nontrivial refinement, e.g., when you have a binding site
    # x: Sym(u0 * 4) = y.item().  Here, the code generation must do a division
    # in order to appropriately bind u0.  This is communicated via the keypath
    # in unbacked_bindings, and we need to hold onto it in order to generate
    # code appropriately for this case.
    ((binding_sym, keypath),) = unbacked_bindings.items()
    scalar_node = ir.DynamicScalar(binding_sym, keypath, data)
    scalar_node.name = V.graph.register_buffer(scalar_node)
    V.graph.register_operation(scalar_node)
    # NB: the replaced expr is OK to use directly downstream, we want
    # simplifications in this case!
    val = V.graph.current_node.meta["val"]
    if isinstance(val, (torch.SymInt, torch.SymFloat, torch.SymBool)):
        return val.node.expr
    return sympy.sympify(val)
@register_lowering(aten._assert_scalar)
def _assert_scalar(data, msg):
    """Lower ``aten._assert_scalar`` to nothing; both arguments are ignored."""
    # NB: These will be handled at codegen time
    # Not sure if we are guaranteed to be able to serve out truth from the
    # deferred_runtime_asserts, TODO: try this assert out
    # See [NOTE] Codegen runtime asserts in Inductor
    # assert bool(data.scalar), data
    result = None
    return result
@register_lowering(aten._assert_tensor_metadata)
def _assert_tensor_metadata(
    a, size=None, stride=None, dtype=None, *, device=None, layout=None
):
    """Lower ``aten._assert_tensor_metadata`` to nothing; all arguments are ignored."""
    result = None
    return result
def _full(fill_value, device, dtype, size):
    """Build a ``Pointwise`` of shape ``size`` whose elements all equal ``fill_value``.

    ``fill_value`` may be a Python number, a sympy expression, or an object
    exposing ``get_size()``/``make_loader()`` with zero dimensions. A
    non-numeric ``fill_value`` that has a ``.value`` attribute is replaced by
    that attribute first.
    """
    value = fill_value
    is_python_number = isinstance(fill_value, (int, float))
    if not is_python_number and hasattr(value, "value"):
        value = value.value

    if isinstance(value, (int, float)):

        def inner_fn(index):
            return ops.constant(value, dtype)

    elif isinstance(value, sympy.Basic):

        def inner_fn(index):
            return ops.index_expr(value, dtype)

    else:
        # Only zero-dimensional inputs can be broadcast as a fill value.
        assert len(value.get_size()) == 0
        value_loader = value.make_loader()

        def inner_fn(index):
            return value_loader([])

    return Pointwise.create(
        device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size)
    )
def full_like(x, fill_value, **kwargs):
    """Create a tensor like ``x`` filled with ``fill_value``.

    Composes ``tensor_constructor(fill_value)`` with ``create_tensor_like``
    and calls the result with ``x`` and ``kwargs``.
    """
    constructor = tensor_constructor(fill_value)
    like_fn = create_tensor_like(constructor)
    return like_fn(x, **kwargs)
def tensor_constructor(fill_value):
    """Return a ``torch.zeros``/``torch.ones``-style constructor for ``fill_value``.

    The returned function accepts the size either as separate positional
    values or as a single list/tuple/``torch.Size``, and delegates to
    ``_full``. ``memory_format`` is accepted but not used.
    """

    # torch.zeros, torch.ones, etc
    def inner(
        *size,
        names=None,
        dtype=None,
        device=None,
        layout=None,
        pin_memory=False,
        memory_format=None,
    ):
        assert_nyi(names is None, "named tensors")
        assert_nyi(layout in (None, torch.strided), f"layout={layout}")
        assert_nyi(not pin_memory, "pin_memory")
        device = decode_device(device)
        dtype = dtype or torch.get_default_dtype()
        if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
            size = tuple(size[0])
        # See https://github.com/pytorch/pytorch/issues/118102
        # All sizes at lowering time should be sympy.Symbol, not SymInt!
        assert not any(isinstance(s, torch.SymInt) for s in size)
        expanded_size = [sympy.expand(s) for s in size]
        return _full(fill_value, device, dtype, expanded_size)

    return inner
@register_lowering([torch.empty, aten.empty])
def empty(
    *size,
    names=None,
    dtype=None,
    layout=None,
    device=None,
    pin_memory=None,
    memory_format=None,
):
    """Lower ``torch.empty`` / ``aten.empty`` via ``empty_strided`` with no strides.

    The size may be given as separate positional values or as a single
    list/tuple/``torch.Size``. ``memory_format`` is accepted but not used.
    """
    assert_nyi(names is None, "named tensors")
    target_device = decode_device(device)
    if len(size) == 1 and isinstance(size[0], (list, tuple, torch.Size)):
        size = tuple(size[0])
    return empty_strided(
        size,
        None,
        dtype=dtype,
        layout=layout,
        device=target_device,
        pin_memory=pin_memory,
    )
def create_tensor_like(creation_fn):
    """
    Shim to convert X_like(...) into X(...).  For example zeros_like() into zeros().

    The returned function takes the size from ``x`` and defaults ``dtype`` and
    ``device`` to those of ``x``. ``memory_format`` is accepted but not
    forwarded to ``creation_fn``.
    """

    def _constant_like(
        x, *, dtype=None, device=None, layout=None, pin_memory=False, memory_format=None
    ):
        assert_nyi(not pin_memory, "pin_memory")
        assert_nyi(layout in (None, torch.strided), f"layout={layout}")
        dtype = x.get_dtype() if dtype is None else decode_dtype(dtype)
        device = device or x.get_device()
        like_size = list(x.get_size())
        return creation_fn(
            like_size,
            dtype=dtype,
            device=device,
            layout=layout,
            pin_memory=pin_memory,
        )

    return _constant_like
def constant_like(fill_value):
    """Return an ``X_like``-style function that fills with ``fill_value``."""
    constructor = tensor_constructor(fill_value)
    return create_tensor_like(constructor)
# *_like helpers built from their plain constructors; only empty_like is
# registered as a lowering here.
empty_like = register_lowering(aten.empty_like)(create_tensor_like(empty))
ones_like = constant_like(1)
zeros_like = constant_like(0)
def new_constant(fill_value):
    """Return a ``new_*``-style constructor that fills with ``fill_value``.

    The returned function requires ``size`` to be a list or tuple, converts
    each entry with ``sympy.Integer``, and defaults ``dtype`` and ``device``
    to those of ``x``.
    """

    def _new_constant(
        x, size, *, dtype=None, layout=None, device=None, pin_memory=None
    ):
        assert isinstance(size, (list, tuple))
        assert_nyi(not pin_memory, "pin_memory")
        assert_nyi(layout in (None, torch.strided), f"layout={layout}")
        dtype = decode_dtype(dtype) or x.get_dtype()
        device = device or x.get_device()
        int_size = [sympy.Integer(s) for s in size]
        target_device = decode_device(device)
        return _full(fill_value, target_device, dtype, int_size)

    return _new_constant
@register_lowering(aten.new_empty)
def new_empty(x, size, *, dtype=None, layout=None, device=None, pin_memory=None):
    """Lower ``aten.new_empty`` via ``empty_strided`` with no strides.

    ``dtype`` and ``device`` default to those of ``x`` when not given.
    """
    resolved_dtype = x.get_dtype() if dtype is None else dtype
    resolved_device = x.get_device() if device is None else device
    return empty_strided(
        size,
        None,
        dtype=resolved_dtype,
        layout=layout,
        device=decode_device(resolved_device),
        pin_memory=pin_memory,
    )
@register_lowering(aten.empty_strided)
def empty_strided(
    size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Lower ``aten.empty_strided``.

    Builds a realized zero-filled ``Pointwise`` of shape ``size``, overrides
    the ranges of the underlying buffer with zeros, and pins its layout to an
    ``ir.FixedLayout`` with the requested ``stride`` (contiguous strides when
    ``stride`` is ``None`` or empty).

    Args:
        size: list or tuple of sizes.
        stride: list or tuple of strides, or ``None``.
        dtype: output dtype; defaults to the default dtype.
        layout: only ``None`` or ``torch.strided`` is supported.
        device: output device; defaults to the device of ``torch.tensor(0.0)``.
        pin_memory: must be falsy.

    Returns:
        The realized ``Pointwise`` wrapper whose buffer carries the fixed layout.
    """
    assert isinstance(size, (list, tuple))
    assert isinstance(stride, (list, tuple, type(None)))
    assert_nyi(not pin_memory, "pin_memory")
    assert_nyi(layout in (None, torch.strided), f"layout={layout}")
    dtype = decode_dtype(dtype) or torch.get_default_dtype()
    device = device or torch.tensor(0.0).device
    device = decode_device(device)
    pointwise = _full(fill_value=0, device=device, dtype=dtype, size=size)
    pointwise.realize()
    buffer = pointwise.data.data
    # Verify the buffer type before its fields are read and overwritten below.
    assert isinstance(buffer, ir.ComputedBuffer)
    # explicitly set ranges to zeros in order to make a NopKernelSchedulerNode
    buffer.data = dataclasses.replace(buffer.data, ranges=[0] * len(size))
    size = [sympy.expand(s) for s in size]
    stride = (
        [sympy.expand(s) for s in stride]
        if stride
        else ir.FlexibleLayout.contiguous_strides(size)
    )
    buffer.layout = ir.FixedLayout(
        device=device,
        dtype=dtype,
        size=size,
        stride=stride,
    )
    return pointwise
@register_lowering(aten.new_empty_strided)
def new_empty_strided(
    x, size, stride, *, dtype=None, layout=None, device=None, pin_memory=None
):
    """Lower ``aten.new_empty_strided`` by delegating to ``empty_strided``.

    ``dtype`` and ``device`` default to those of ``x`` when omitted.
    """
    dtype = x.get_dtype() if dtype is None else dtype
    device = x.get_device() if device is None else device
    return empty_strided(
        size,
        stride,
        dtype=dtype,
        layout=layout,
        device=decode_device(device),
        pin_memory=pin_memory,
    )
@register_lowering(prims.copy_strided.default)
def copy_strided(x, stride):
    """Lower ``prims.copy_strided`` by requiring a stride order on ``x``.

    Each stride is resolved through ``size_hint_or_throw``; the dimensions
    are then ranked by ascending hinted stride and that ranking is handed to
    ``ExternKernel.require_stride_order``.
    """
    hinted = [V.graph.sizevars.size_hint_or_throw(step) for step in stride]
    order = sorted(range(len(hinted)), key=lambda d: hinted[d])
    return ir.ExternKernel.require_stride_order(x, order)
@register_lowering([torch.full, aten.full])
def full(size, fill_value, **kwargs):
    """Lower ``full`` via ``tensor_constructor``; a ``dtype`` kwarg is required."""
    has_dtype = kwargs.get("dtype") is not None
    assert has_dtype, "dtype should be handled by decomposition"
    constructor = tensor_constructor(fill_value)
    return constructor(size, **kwargs)
@register_lowering(aten.gather, type_promotion_kind=None)
def gather(x, dim, index, sparse_grad=False):
    """Lower ``aten.gather`` to a pointwise op that indexes ``x`` indirectly.

    The output has the size of ``index``; along ``dim`` the coordinate is
    replaced by the value loaded from ``index``. A zero-dim ``x`` is expanded
    to size ``[1]`` first. An empty ``index`` short-circuits to ``new_empty``.
    """
    # sparse_grad doesn't affect forward computation,
    # and backward tracing is taken care of by AOT Autograd
    assert isinstance(x, TensorBox)
    if index.get_numel() == 0:
        # Empty index case. Return an empty array with the same shape
        return new_empty(x, index.get_size())

    size = x.get_size()
    is_scalar = len(size) == 0
    dim = _validate_dim(x, dim, is_scalar)

    if is_scalar:
        x = expand(x, [1])
        size = [1]

    x_loader = x.make_loader()
    index_loader = index.make_loader()

    def fn(idx):
        coords = list(idx)
        gathered = ops.indirect_indexing(index_loader(coords), size[dim])
        if len(coords) == 0:
            coords = [gathered]
        else:
            coords[dim] = gathered
        return x_loader(coords)

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=fn,
        ranges=index.get_size(),
    )
@register_lowering(aten.embedding, type_promotion_kind=None)
def embedding(weight, indices, padding_idx=-1, scale_grad_by_freq=False, sparse=False):
    """Lower ``aten.embedding`` to a pointwise lookup into ``weight``.

    ``sparse=True`` is routed to the ATen fallback. Otherwise the output size
    is the size of ``indices`` followed by the trailing dimensions of
    ``weight``; the leading output coordinates select an index value that is
    used as the row of ``weight``.
    """
    if sparse:
        fallback = fallback_handler(aten.embedding.default)
        return fallback(weight, indices, padding_idx, scale_grad_by_freq, sparse)

    assert not sparse
    assert isinstance(weight, TensorBox)
    assert isinstance(indices, TensorBox)
    assert "int" in str(indices.get_dtype())

    weight_loader = weight.make_loader()
    indices_loader = indices.make_loader()
    indices_ndim = len(indices.get_size())
    weight_size = weight.get_size()
    new_size = [*indices.get_size(), *weight_size[1:]]

    def fn(idx):
        assert len(idx) == len(new_size), f"{idx} != {new_size}"
        row = indices_loader(idx[:indices_ndim])
        weight_idx = [
            ops.indirect_indexing(row, weight_size[0]),
            *idx[indices_ndim:],
        ]
        return weight_loader(weight_idx)

    return Pointwise.create(
        device=weight.get_device(),
        dtype=weight.get_dtype(),
        inner_fn=fn,
        ranges=new_size,
    )
def check_and_broadcast_indices(indices, device):
    """Validate index tensors and broadcast them against each other.

    Args:
        indices: sequence whose entries are index values or ``None``.
        device: device every index is required to live on.

    Returns:
        ``(new_indices, valid_idxs)`` where ``new_indices`` has the same
        length as ``indices`` with the ``TensorBox`` entries replaced by their
        broadcast versions (other positions are ``None``), and ``valid_idxs``
        lists the positions of the ``TensorBox`` entries.

    Raises:
        AssertionError: an index has an unsupported dtype, or no entry is a
            ``TensorBox``.
        NotImplementedError: an index is bool/uint8, or lives on a different
            device; callers use this to trigger a fallback.
    """
    index_dtypes = [i.get_dtype() for i in indices if i is not None]
    # int32 is accepted too, so the message names it alongside int64.
    assert all(
        dt in (torch.int64, torch.int32, torch.bool, torch.uint8)
        for dt in index_dtypes
    ), f"indices must be int64, int32, byte or bool. Got {index_dtypes}"
    if any(dt in (torch.bool, torch.uint8) for dt in index_dtypes):
        raise NotImplementedError("Fallback for bool indices")

    valid_idxs = [i for i, x in enumerate(indices) if isinstance(x, TensorBox)]
    assert len(valid_idxs) > 0, "requires at least 1 non-None index"
    new_indices = [None] * len(indices)
    for i, x in zip(valid_idxs, broadcast_tensors(*[indices[i] for i in valid_idxs])):
        # Eager allows indices to be CPU tensor when running on CUDA
        # FIXME: Calling to_device(x, device) should work but
        # test_advancedindex_mixed_cpu_devices still fails
        if x.get_device() != device:
            raise NotImplementedError("Fallback when indices is on a different device")
        new_indices[i] = x
    return new_indices, valid_idxs
def index_output_size_and_inner_fn(
    x_size,
    indices,
    tensor_indices,
    tensor_size,
    indices_loaders,
    indexed_size,
    x_loader,
    check,
    wrap_neg=True,
):
    """Compute the result size of an advanced-indexing op and its index mapper.

    Returns ``(output_size, fn)``. ``fn`` maps an output index to the index
    into the indexed value (or, when ``x_loader`` is given, to the loaded
    value at that index).
    """
    # Note that behavior of indexing differs when there are non consecutive
    # tensors. In this case, the tensor index is pulled to the beginning.
    #
    # Suppose a = torch.arange(3 * 4 * 5 * 6 * 7).view(3, 4, 5, 6, 7)
    #         x = torch.tensor[1,2]
    # Then, a[:,x,:,x,:] will have shape 2,3,5,7 as due to x,:,x then 2 will
    # be pulled to the front.
    non_consecutive_tensors = any(
        later - earlier != 1
        for earlier, later in zip(tensor_indices, tensor_indices[1:])
    )

    # Sizes of the dimensions that are not indexed by a tensor: the `None`
    # slots first, then everything past the explicitly listed indices.
    kept_sizes = [x_size[pos] for pos, entry in enumerate(indices) if entry is None]
    kept_sizes = [*kept_sizes, *x_size[len(kept_sizes) + len(tensor_indices) :]]

    if non_consecutive_tensors:
        output_size = tensor_size + kept_sizes
    else:
        split_at = tensor_indices[0]
        output_size = kept_sizes[:split_at] + tensor_size + kept_sizes[split_at:]

    def fn(idx):
        assert len(idx) == len(output_size)
        assert len(indices_loaders) == len(indexed_size)

        rank = len(tensor_size)
        start_offset = 0 if non_consecutive_tensors else tensor_indices[0]
        new_index = []
        next_idx = 0
        for i in range(tensor_indices[-1] + 1):
            if i == start_offset:
                # Skip over the output dims that belong to the index tensors.
                next_idx += rank
            if indices[i] is None:
                assert next_idx < len(idx)
                new_index.append(idx[next_idx])
                next_idx += 1
                continue
            loader = indices_loaders[i]
            assert loader is not None
            size = indexed_size[i]
            loaded = loader(idx[start_offset : start_offset + rank])
            new_index.append(
                ops.indirect_indexing(loaded, size, check=check, wrap_neg=wrap_neg)
            )
        new_index = [*new_index, *idx[next_idx:]]
        return new_index if x_loader is None else x_loader(new_index)

    return output_size, fn
def index_impl(x, indices, check):
    """Build the pointwise node for indexing ``x`` with ``indices``.

    Size and loader come from ``index_impl_helper``; ``check`` is forwarded
    to it unchanged.
    """
    ranges, load_indexed, _ = index_impl_helper(x, indices, check)
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=load_indexed,
        ranges=ranges,
    )
def index_impl_helper(x, indices, check, wrap_neg=True):
    """Prepare everything needed to lower an indexing op on ``x``.

    Returns ``(output_size, inner_fn, index_inner_fn)``: the result size, a
    loader that reads ``x`` at the mapped index, and the bare index-mapping
    function.

    Raises:
        IndexError: when ``check`` is set, an indexed dimension has size 0
            and the broadcast index size contains no 0.
    """
    assert isinstance(indices, (list, tuple))
    x_loader = x.make_loader()
    indices, tensor_indices = check_and_broadcast_indices(indices, x.get_device())
    assert len(tensor_indices) > 0, "Must have at least one valid idx"

    indices_loaders = [
        None if entry is None else entry.make_loader() for entry in indices
    ]
    # no guards on output size, all the guards are set in broadcast_tensors

    # We can use the first one since they are all required to be the same size
    tensor_size = list(indices[tensor_indices[0]].get_size())

    x_size = x.get_size()

    tensor_indexed_sizes = [
        x_size[pos] for pos, entry in enumerate(indices) if entry is not None
    ]
    if check and 0 in tensor_indexed_sizes and 0 not in tensor_size:
        raise IndexError("index is out of bounds for dimension with size 0")

    indexed_size = [x_size[pos] for pos in range(len(indices))]
    output_size, index_inner_fn = index_output_size_and_inner_fn(
        x_size,
        indices,
        tensor_indices,
        tensor_size,
        indices_loaders,
        indexed_size,
        None,
        check=check,
        wrap_neg=wrap_neg,
    )

    def inner_fn(idx):
        return x_loader(index_inner_fn(idx))

    return output_size, inner_fn, index_inner_fn
@register_lowering(aten.index, type_promotion_kind=None)
def index(x, indices):
    """Lower ``aten.index`` with bounds checks, falling back to ATen if needed.

    ``index_impl`` signals unsupported cases with ``NotImplementedError``; in
    that case ``x`` is realized and the op goes through the fallback handler.
    """
    try:
        return index_impl(x, indices, check=True)
    except NotImplementedError:
        # Fallback to ATen for boolean indexing
        x.realize()
        aten_index = fallback_handler(aten.index.Tensor, add_to_fallback_set=False)
        return aten_index(x, indices)
@register_lowering(aten._unsafe_index, type_promotion_kind=None)
def _unsafe_index(x, indices):
    """Lower ``aten._unsafe_index``: same as ``index`` but with ``check=False``."""
    unchecked = index_impl(x, indices, check=False)
    return unchecked
- # All the indexing decompositions are written in terms of index, index_put, and index_put_
- # We cannot have this lowering as a decomposition as it introduces
- # mutation in the graph, which is bad for Aot Autograd. Aot Autograd runs dead
- # code elimination and common subexpression elimination optimizations, which
- # assume graphs to be side-effect free. More details at
- # https://github.com/pytorch/torchdynamo/issues/1235
- # and
- # https://github.com/pytorch/torchdynamo/issues/1863
@register_lowering(aten.index_put, type_promotion_kind=None)
def index_put(x, indices, values, accumulate=False):
    """Lower functional ``aten.index_put`` as the in-place op on a clone of ``x``."""
    target = clone(x)
    return index_put_impl_(
        target, indices, values, accumulate, check=True, may_realize=False
    )
@register_lowering(aten._unsafe_index_put)
def _unsafe_index_put(x, indices, values, accumulate=False):
    """Lower ``aten._unsafe_index_put`` on a clone of ``x`` with ``check=False``."""
    target = clone(x)
    return index_put_impl_(
        target, indices, values, accumulate, check=False, may_realize=False
    )
def index_put_as_masked_fill(self, indices, value, accumulate):
    """Express an index_put with a single mask index as ``where`` + ``mutate_to``.

    ``value`` is moved to the device of ``self`` if necessary; with
    ``accumulate`` the written value is ``self + value``. Only ``indices[0]``
    is used, as the mask.
    """
    target_device = self.get_device()
    if value.get_device() != target_device:
        value = to_device(value, self.get_device())
    if accumulate:
        value = add(self, value)
    mask = indices[0]
    return mutate_to(self, where(mask, value, self))
def index_put_fallback(self, indices, values, accumulate):
    """Emit an ``ir.IndexPutFallback`` for the current graph node and return ``self``."""
    target_op = V.graph.current_node.target
    assert isinstance(target_op, torch._ops.OpOverload)
    # Constructed for its side effect; the result object itself is not used.
    ir.IndexPutFallback(target_op, self, indices, values, accumulate)
    return self
@register_lowering(aten.index_put_, type_promotion_kind=None)
def index_put_(self, indices, values, accumulate=False):
    """Lower in-place ``aten.index_put_`` with bounds checks enabled."""
    result = index_put_impl_(
        self, indices, values, accumulate, check=True, may_realize=True
    )
    return result
@register_lowering(inductor_prims._unsafe_index_put_, type_promotion_kind=None)
def _unsafe_index_put_(self, indices, values, accumulate=False):
    """Lower in-place ``_unsafe_index_put_`` with bounds checks disabled."""
    result = index_put_impl_(
        self, indices, values, accumulate, check=False, may_realize=True
    )
    return result
def index_put_impl_(self, indices, values, accumulate, check, may_realize=False):
    """Shared implementation behind the index_put lowerings; mutates ``self``.

    Args:
        self: destination value that is written to.
        indices: sequence of index values (entries may be ``None``).
        values: values to write; cast to the dtype of ``self`` and expanded
            to the expected size on the scatter path.
        accumulate: when true, writes use the ``"atomic_add"`` scatter mode.
        check: forwarded to the indirect-indexing helper.
        may_realize: when true, ``values`` may be realized into its own
            buffer first if it reads from the buffer of ``self``.

    Returns:
        ``self`` (re-viewed to zero-dim if it started as a scalar), or the
        result of the masked-fill path.
    """
    if may_realize:

        def try_get_name(x):
            # Peel TensorBox / view / StorageBox wrappers to reach a named buffer.
            if isinstance(x, ir.TensorBox):
                x = x.data
            if isinstance(x, ir.BaseView):
                x = x.unwrap_view()
            if isinstance(x, ir.StorageBox):
                x = x.data
            return x.get_name() if isinstance(x, ir.Buffer) else None

        def indice_slice_from_randperm(indice):
            # Refer to: https://github.com/pytorch/pytorch/pull/139366#discussion_r1825424660
            # For this specific pattern, indices is unique as coming from torch.randperm.
            # However, as the content of the indices is unknown, we have to check this specific pattern.
            if isinstance(indice, TensorBox) and isinstance(indice.data, ir.BaseView):
                indice = indice.data.unwrap_view()
                return (
                    isinstance(indice, ir.StorageBox)
                    and isinstance(indice.data, ir.ExternKernel)
                    and getattr(indice.data, "fx_node", None)
                    and indice.data.fx_node.target == torch.ops.aten.randperm.default
                )
            return False

        if try_get_name(self) in values.get_read_names() and not all(
            indice_slice_from_randperm(indice) for indice in indices
        ):
            # Fix issue: https://github.com/pytorch/pytorch/issues/138908
            # When self and values have memory overlapping, indices may
            # contain duplicate values, potentially causing incorrect results since
            # the load of `values` might contain modified value from the store of `self`.
            # To address this, store values in a temporary buffer in such cases.
            values.realize()

    # Dispatch to masked fill for single boolean index with single value
    if (
        values.get_numel() == 1
        and len(indices) == 1
        and indices[0].get_dtype() in (torch.bool, torch.uint8)
    ):
        mask = indices[0]
        # Append trailing unit dims until the mask has the rank of self.
        for _ in range(len(mask.get_size()), len(self.get_size())):
            mask = unsqueeze(mask, -1)
        return index_put_as_masked_fill(self, [mask], values, accumulate)

    # Fallback in torch deterministic mode
    if torch.are_deterministic_algorithms_enabled():
        return index_put_fallback(self, indices, values, accumulate)

    # Fallback if there is a boolean index
    for index in indices:
        if index is not None and index.get_dtype() in (torch.bool, torch.uint8):
            return index_put_fallback(self, indices, values, accumulate)

    x_size = self.get_size()
    x_ndim = len(x_size)

    if accumulate and needs_fallback_due_to_atomic_add_limitations(self.get_dtype()):
        # self is an scalar Tensor
        if x_ndim == 0:
            self = view(self, [1])
        self = index_put_fallback(self, indices, values, accumulate)
        if x_ndim == 0:
            self = view(self, [])
        return self

    values = to_dtype(values, self.get_dtype())

    try:
        # Note that code will only get here when dtype is uint32
        indices, tensor_indices = check_and_broadcast_indices(
            indices, self.get_device()
        )
    except NotImplementedError:
        # `indices` is still the caller's original list here because the
        # assignment above did not complete.
        return index_put_fallback(self, indices, values, accumulate)

    indices_loaders = [i.make_loader() if i is not None else None for i in indices]

    assert isinstance(self, TensorBox)
    self.realize()

    # self is an scalar Tensor
    if x_ndim == 0:
        self = view(self, [1])

    # We can use the first one since they are all required to be the same size
    tensor_size = list(indices[tensor_indices[0]].get_size())
    indexed_size = [x_size[i] for i in range(len(indices))]

    # x_loader is None, so inner_fn returns the destination index itself.
    expected_vals_size, inner_fn = index_output_size_and_inner_fn(
        x_size,
        indices,
        tensor_indices,
        tensor_size,
        indices_loaders,
        indexed_size,
        None,
        check=check,
    )

    values = expand(values, expected_vals_size)
    # all guards are set above during broadcast_tensors and expand

    device = self.get_device()
    assert device is not None
    scatter = ir.Scatter(
        device=device,
        dtype=self.get_dtype(),
        inner_fn=values.make_loader(),
        ranges=expected_vals_size,  # iter_ranges,
        output_indexer=inner_fn,
        scatter_mode="atomic_add" if accumulate else None,
    )
    # The buffer's layout targets `self`, so the scatter writes into it.
    buffer = ir.ComputedBuffer(
        name=None,
        layout=ir.MutationLayoutSHOULDREMOVE(self),
        data=scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)

    if x_ndim == 0:
        self = view(self, [])
    return self
# Fallback handlers for the masked-index ops, created with
# add_to_fallback_set=False.
fallback__unsafe_masked_index = fallback_handler(
    aten._unsafe_masked_index.default, add_to_fallback_set=False
)
fallback__unsafe_masked_index_put_accumulate = fallback_handler(
    aten._unsafe_masked_index_put_accumulate.default, add_to_fallback_set=False
)
@register_lowering(aten._unsafe_masked_index, type_promotion_kind=None)
def _unsafe_masked_index(self, mask, indices, fill):
    """Lower ``aten._unsafe_masked_index`` to a masked pointwise load.

    Where ``mask`` is true the value is loaded from ``self`` at the mapped
    index; elsewhere ``fill`` is produced. Indexing is done without bounds
    checks and without negative-index wrapping.
    """
    ranges, _, _unsafe_index_fn = index_impl_helper(
        self, indices, check=False, wrap_neg=False
    )
    mask_loader = mask.make_loader()
    self_loader = self.make_loader()

    def inner_fn(idx):
        mask_val = mask_loader(idx)
        if mask.dtype != torch.bool:
            mask_val = ops.to_dtype(mask_val, torch.bool)

        def load_selected():
            return self_loader(_unsafe_index_fn(idx))

        return ops.masked(mask_val, load_selected, fill)

    return Pointwise.create(
        device=self.get_device(),
        dtype=self.get_dtype(),
        inner_fn=inner_fn,
        ranges=ranges,
    )
@register_lowering(aten._unsafe_masked_index_put_accumulate, type_promotion_kind=None)
def _unsafe_masked_index_put_accumulate(x, mask, indices, values):
    """Lower ``aten._unsafe_masked_index_put_accumulate``.

    Values outside ``mask`` are replaced with 0 so they contribute nothing to
    the accumulation, and every non-``None`` index is clamped into
    ``[-size, size - 1]`` of its dimension before delegating to
    ``_unsafe_index_put`` with ``accumulate=True``.
    """
    masked_value = where(mask, values, 0)
    shape = x.get_size()
    # Test for None explicitly instead of relying on the truth value of an
    # IR node to decide whether a slot holds an index.
    clamped_indices = [
        clamp(index, -shape[i], shape[i] - 1) if index is not None else None
        for i, index in enumerate(indices)
    ]
    # TODO: use a masked store for this. currently only triton
    # supports masked stores and cpp backend does not.
    return _unsafe_index_put(x, clamped_indices, masked_value, accumulate=True)
@make_pointwise
def clamp(a, min, max):
    """Pointwise clamp: ``maximum(min, minimum(max, a))``."""
    capped = ops.minimum(max, a)
    return ops.maximum(min, capped)
@register_lowering(aten.as_strided_scatter, type_promotion_kind=None)
def as_strided_scatter(self, src, size, stride, storage_offset=None):
    """Lower ``aten.as_strided_scatter``.

    Clones ``self``, copies ``src`` into an ``as_strided`` view of the clone
    and returns the clone.
    """
    result = clone(self)
    window = as_strided(result, size, stride, storage_offset)
    copy_(window, src)
    return result
@register_lowering(aten.scatter, type_promotion_kind=None)
def scatter(x, dim: int, index, src, **kwargs):
    """Lower functional ``aten.scatter`` as ``scatter_`` on a clone of ``x``."""
    target = clone(x)
    return scatter_(target, dim, index, src, **kwargs)
def scatter_fallback(
    op_overload: torch._ops.OpOverload,
    self,
    dim: int,
    index,
    src,
    *,
    reduce: Optional[str] = None,
    include_self: bool = True,
):
    """Emit an ``ir.ScatterFallback`` when ``use_scatter_fallback`` asks for one.

    Returns:
        ``self`` if the fallback node was created, otherwise ``None`` so the
        caller can continue with its own lowering.
    """
    src_is_tensor = isinstance(src, TensorBox)
    self_dtype = self.get_dtype()
    src_dtype = cast(torch.dtype, src.get_dtype() if src_is_tensor else type(src))
    src_device_type = src.get_device().type if src_is_tensor else "not impl"

    if not use_scatter_fallback(
        op_overload,
        reduce,
        self_dtype,
        src_dtype,
        src_device_type,
        src_is_tensor,
    ):
        return None

    # Constructed for its side effect; the caller gets `self` back.
    ir.ScatterFallback(
        op_overload,
        self,
        dim,
        index,
        src,
        reduce=reduce,
        include_self=include_self,
    )
    return self
@register_lowering(aten.scatter_, type_promotion_kind=None)
def scatter_(self, dim: int, index, src, *, reduce: Optional[str] = None):
    """Lower in-place ``aten.scatter_``.

    Without ``reduce`` the fallback is tried first, using the overload of the
    current graph node. Otherwise (or if no fallback is taken) the op is
    forwarded to ``scatter_reduce_`` with ``"add"``/``"multiply"`` renamed to
    ``"sum"``/``"prod"``.
    """
    assert reduce in (None, "add", "multiply")

    if reduce is None:
        overload_name = V.graph.current_node.target._overloadname  # type: ignore[union-attr]
        op_overload = getattr(aten.scatter_, overload_name)
        fallback_result = scatter_fallback(
            op_overload, self, dim, index, src, reduce=reduce
        )
        if fallback_result is not None:
            return fallback_result

    renamed = {"add": "sum", "multiply": "prod"}
    reduce = renamed.get(reduce, reduce)
    return scatter_reduce_(self, dim, index, src, reduce)
@register_lowering(aten.scatter_add, type_promotion_kind=None)
def scatter_add(x, dim: int, index, src):
    """Lower functional ``aten.scatter_add`` as ``scatter_add_`` on a clone of ``x``."""
    target = clone(x)
    return scatter_add_(target, dim, index, src)
@register_lowering(aten.scatter_add_, type_promotion_kind=None)
def scatter_add_(x, dim: int, index, src):
    """Lower in-place ``aten.scatter_add_`` as ``scatter_reduce_`` with ``"sum"``."""
    reduction = "sum"
    return scatter_reduce_(x, dim, index, src, reduction)
@register_lowering(aten.scatter_reduce, type_promotion_kind=None)
def scatter_reduce(x, dim: int, index, src, reduction_type, **kwargs):
    """Lower functional ``aten.scatter_reduce`` as ``scatter_reduce_`` on a clone of ``x``."""
    target = clone(x)
    return scatter_reduce_(target, dim, index, src, reduction_type, **kwargs)
@register_lowering(aten.scatter_reduce_, type_promotion_kind=None)
def scatter_reduce_(self, dim: int, index, src, reduce, *, include_self: bool = True):
    """Lower in-place ``aten.scatter_reduce_``; mutates and returns ``self``.

    Args:
        self: destination value.
        dim: dimension along which ``index`` selects the destination.
        index: integer index value; its size defines the iteration ranges.
        src: source value or Python number (a number is first turned into a
            value through ``full_like``).
        reduce: ``None`` or one of ``"sum"``, ``"prod"``, ``"mean"``,
            ``"amax"``, ``"amin"``. Only ``None`` and ``"sum"`` are handled
            by the scatter emitted here; the others must have been taken by
            ``scatter_fallback`` or the inner assertion fails.
        include_self: when false, the targeted elements are first
            overwritten with 0 by a separate scatter.

    Returns:
        ``self``, or the fallback result when ``scatter_fallback`` handled
        the op.
    """
    assert reduce in (None, "sum", "prod", "mean", "amax", "amin")
    assert (
        len(aten.scatter_reduce_.overloads()) == 1
        and "two" in aten.scatter_reduce_.overloads()
    ), "aten.scatter_reduce_.two is not the unique overload of aten.scatter_reduce_"

    if isinstance(src, Number):
        src = full_like(self, src)

    fallback_result = scatter_fallback(
        aten.scatter_reduce_.two,
        self,
        dim,
        index,
        src,
        reduce=reduce,
        include_self=include_self,
    )

    # scatter_fallback returns None when no fallback was taken; compare with
    # None explicitly, matching scatter_, instead of testing truthiness.
    if fallback_result is not None:
        return fallback_result

    assert isinstance(self, TensorBox)
    assert "int" in str(index.get_dtype())

    ndim = len(self.get_size())
    if ndim == 0:
        self = view(self, [1])

    if isinstance(src, TensorBox) and len(src.get_size()) == 0:
        src = view(src, [1])

    if isinstance(index, TensorBox) and len(index.get_size()) == 0:
        index = view(index, [1])

    if index.get_numel() == 0:
        return self

    dim = _validate_dim(self, dim)

    self.realize()
    index_loader = index.make_loader()
    src_loader = src.make_loader() if isinstance(src, TensorBox) else None

    def output_indexer(idx):
        # self is captured from the end of the function, so it may have 0 dim
        shape = self.get_size()
        ndim = len(shape)
        indirect_idx = list(idx)
        indirect_idx[dim] = ops.indirect_indexing(
            index_loader(idx), 1 if ndim == 0 else shape[dim], wrap_neg=False
        )
        return indirect_idx

    def fn(idx):
        if src_loader:
            return src_loader(idx)
        else:
            # src is a scalar
            return ops.constant(src, self.get_dtype())

    def backend_reduce_str(reduce):
        if reduce == "sum":
            return "atomic_add"
        else:
            # TODO: Need to support more reduction type
            assert reduce is None
            return None

    device = self.get_device()
    assert device is not None

    if not include_self:
        # zero out the corresponding elements first
        zero_out = ir.Scatter(
            device=device,
            dtype=self.get_dtype(),
            inner_fn=lambda index: ops.constant(0, self.get_dtype()),
            ranges=index.get_size(),
            output_indexer=output_indexer,
            scatter_mode=None,
        )
        buffer = ir.ComputedBuffer(
            name=None,
            layout=ir.MutationLayoutSHOULDREMOVE(self),
            data=zero_out,
        )
        buffer.name = V.graph.register_buffer(buffer)
        V.graph.register_operation(buffer)

    # self[index[i][j][k]][j][k] += src[i][j][k]  # if dim == 0
    # self[i][index[i][j][k]][k] += src[i][j][k]  # if dim == 1
    # self[i][j][index[i][j][k]] += src[i][j][k]  # if dim == 2
    scatter = ir.Scatter(
        device=device,
        dtype=self.get_dtype(),
        inner_fn=fn,
        ranges=index.get_size(),
        output_indexer=output_indexer,
        scatter_mode=backend_reduce_str(reduce),
    )
    buffer = ir.ComputedBuffer(
        name=None,
        layout=ir.MutationLayoutSHOULDREMOVE(self),
        data=scatter,
    )
    buffer.name = V.graph.register_buffer(buffer)
    V.graph.register_operation(buffer)

    if ndim == 0:
        self = view(self, [])
    return self
def upsample_nearestnd(
    x,
    output_size,
    scales_x: tuple[Optional[float], ...],
    n: int = 2,
    exact: bool = False,
):
    """Build a nearest-neighbour upsampling over the last ``n`` dims of ``x``.

    Args:
        x: input value; its leading dims are kept as batch dims.
        output_size: target sizes of the last ``n`` dims.
        scales_x: per-dim scale overrides (``None`` means derive the inverse
            scale from input/output size); must have length ``n``.
        n: number of trailing spatial dims.
        exact: add 0.5 to the output index before scaling.
    """
    x.realize_hint()  # elements are reused
    x_loader = x.make_loader()
    i_sizes = x.get_size()[-n:]
    batch = x.get_size()[:-n]
    i_sizes = [V.graph.sizevars.guard_int(extent) for extent in i_sizes]

    assert len(scales_x) == n
    o_sizes = output_size

    inv_scales = [in_sz / out_sz for in_sz, out_sz in zip(i_sizes, o_sizes)]
    for axis, scale in enumerate(scales_x):
        if scale is not None:
            inv_scales[axis] = 1.0 / scale

    def scale_fn(x, scale, size):
        # Nearest Exact: input_index = round(scale * (output_index + 0.5) - 0.5)
        #                            = floor(scale * (output_index + 0.5))
        # Nearest: input_index = floor(scale * output_index)
        pos = ops.index_expr(x, torch.float32)
        if exact:
            pos = ops.add(pos, ops.constant(0.5, torch.float32))
        scaled = ops.mul(pos, ops.constant(scale, torch.float32))
        src_index = ops.to_dtype(scaled, torch.int32)
        return ops.indirect_indexing(src_index, size, check=False)

    def fn(idx):
        spatial = idx[-n:]
        lead = idx[:-n]
        mapped = [
            scale_fn(i, s, size) for i, s, size in zip(spatial, inv_scales, i_sizes)
        ]
        return x_loader([*lead, *mapped])

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=fn,
        ranges=[*batch, *o_sizes],
    )
@register_lowering(aten.upsample_nearest1d.default)
def upsample_nearest1d(x, output_size, scales: Optional[float] = None):
    """Lower ``aten.upsample_nearest1d`` via ``upsample_nearestnd`` with ``n=1``."""
    per_dim_scales = (scales,)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=1)
@register_lowering(aten._upsample_nearest_exact1d.default)
def _upsample_nearest_exact1d(x, output_size, scales: Optional[float] = None):
    """Lower ``aten._upsample_nearest_exact1d`` via ``upsample_nearestnd`` (``n=1``, exact)."""
    per_dim_scales = (scales,)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=1, exact=True)
@register_lowering(aten.upsample_nearest2d.default)
def upsample_nearest2d(
    x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
):
    """Lower ``aten.upsample_nearest2d`` via ``upsample_nearestnd`` with ``n=2``."""
    per_dim_scales = (scales_h, scales_w)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=2)
@register_lowering(aten._upsample_nearest_exact2d.default)
def _upsample_nearest_exact2d(
    x, output_size, scales_h: Optional[float] = None, scales_w: Optional[float] = None
):
    """Lower ``aten._upsample_nearest_exact2d`` via ``upsample_nearestnd`` (``n=2``, exact)."""
    per_dim_scales = (scales_h, scales_w)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=2, exact=True)
@register_lowering(aten.upsample_nearest3d.default)
def upsample_nearest3d(
    x,
    output_size,
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    """Lower ``aten.upsample_nearest3d`` via ``upsample_nearestnd`` with ``n=3``."""
    per_dim_scales = (scales_d, scales_h, scales_w)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=3)
@register_lowering(aten._upsample_nearest_exact3d.default)
def _upsample_nearest_exact3d(
    x,
    output_size,
    scales_d: Optional[float] = None,
    scales_h: Optional[float] = None,
    scales_w: Optional[float] = None,
):
    """Lower ``aten._upsample_nearest_exact3d`` via ``upsample_nearestnd`` (``n=3``, exact)."""
    per_dim_scales = (scales_d, scales_h, scales_w)
    return upsample_nearestnd(x, output_size, per_dim_scales, n=3, exact=True)
- def _create_constants(*args, dtype):
- return tuple(ops.constant(a, dtype) for a in args)
@register_lowering(prims.rev.default)
def rev(x, dims):
    """Lower ``prims.rev``: reverse ``x`` along each dimension in ``dims``.

    For every listed dimension the loaded coordinate becomes
    ``size - 1 - coordinate``; the output has the size of ``x``.
    """
    # note - dims pre-canonicalized
    x_loader = x.make_loader()
    sizes = x.get_size()

    def loader(idx):
        flipped = list(idx)
        assert len(flipped) == len(sizes)
        for axis in dims:
            flipped[axis] = (sizes[axis] - 1) - flipped[axis]
        return x_loader(flipped)

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=loader,
        ranges=sizes,
    )
def inplace_constant_pad_nd(
    x: TensorBox, padding: Sequence[int], fill_value: float
) -> Optional[TensorBox]:
    """
    This optimization changes the semantics of padding from 'clone'
    style to 'view' style.

    Thanks to functionalization, this change can still maintain numerical
    correctness.

    Args:
        x: value to pad; only the 2D case is handled.
        padding: four padding amounts; only a non-zero ``padding[1]`` with
            the other three equal to 0 is accepted.
        fill_value: value written into the padded columns.

    Returns:
        An ``as_strided`` view of ``x`` widened by ``padding[1]`` columns, or
        ``None`` whenever any precondition below is not met (the caller then
        has to pad some other way).
    """

    def _padding_can_be_fused():
        """
        Conservatively check if padding can be fused with downstream op.
        1. if the downstream op is a sum, then there is little benefit to
           do inplace padding
        2. if the downstream op is a matmul, doing inplace padding can
           save membw.
        """
        current_node = V.graph.current_node
        if current_node is None:
            return True  # be conservative
        users = tuple(current_node.users)
        # Only a single mm/addmm user is treated as "cannot be fused".
        if len(users) == 1 and users[0].target in (
            aten.mm.default,
            aten.addmm.default,
        ):
            return False

        return True  # be conservative

    if _padding_can_be_fused():
        return None

    # Only handle 2D case for now
    if len(padding) != 4 or len(x.get_size()) != 2:
        return None

    # No harm to realize since we already know that
    # the op can not be fused into the single user.
    # It need to be realized later anyways.
    x.realize()

    # If x is a view (e.g. a SliceView), realizing it only realizes the
    # underlying storage. x itself is still a view.
    if (
        not isinstance(x, ir.TensorBox)
        or not isinstance(x.data, ir.StorageBox)
        or not (
            isinstance(x.data.data, ir.ComputedBuffer)
            or (
                config.can_inplace_pad_graph_input
                and isinstance(x.data.data, ir.InputBuffer)
            )
        )
        or not x.data.data.name
    ):
        return None
    x.freeze_layout()

    _, layout = ir.as_storage_and_layout(x)
    strides = layout.stride
    # Requires unit stride in the last dimension.
    if strides[1] != 1:
        return None

    # Only padding[1] may be non-zero.
    if padding[0] != 0 or padding[2] != 0 or padding[3] != 0:
        return None

    npad = padding[1]
    if npad == 0:
        return None

    stride0 = strides[0]
    rowsize = layout.size[1]

    # The row stride must already leave room for the extra columns.
    if stride0 < rowsize + npad:
        return None

    bufname = x.data.data.name
    padded_size = [layout.size[0], layout.size[1] + npad]
    V.graph.buffer_to_padded_size[bufname] = padded_size
    # Same strides and offset, wider logical size.
    resized_x = as_strided(
        x,
        padded_size,
        layout.stride,
        layout.offset,
    )

    # Fill only the newly exposed columns [rowsize, rowsize + npad).
    sliced_x = slice_(resized_x, dim=1, start=rowsize, end=rowsize + npad)
    fill_(sliced_x, fill_value)

    counters["inductor"]["inplace_padding"] += 1
    return resized_x
@register_lowering(aten.constant_pad_nd, type_promotion_kind=None)
def constant_pad_nd(x, padding, fill_value=0):
    """Lower ``aten.constant_pad_nd`` to a masked pointwise load.

    Args:
        x: value to pad.
        padding: flat ``(low, high)`` pairs, given for the last dimension
            first; must have even length.
        fill_value: value produced outside the original extent; converted to
            the Python type matching the dtype of ``x``.

    Returns:
        A clone of ``x`` when all padding is zero, the in-place padded view
        when ``config.inplace_padding`` is set and that path applies, or a
        new pointwise node of the padded size otherwise.
    """
    assert (len(padding) % 2) == 0
    if all(p == 0 for p in padding):
        return clone(x)

    if config.inplace_padding:
        out = inplace_constant_pad_nd(x, padding, fill_value)
        # inplace_constant_pad_nd returns Optional[TensorBox]; compare with
        # None explicitly rather than relying on the node's truth value.
        if out is not None:
            return out
            # fall through if can not inplace the padding

    sizes = x.get_size()

    # Pair up (low, high) and reverse so bounds line up with the trailing dims.
    bounds = list(reversed(list(zip(padding[::2], padding[1::2]))))
    n = len(sizes) - len(bounds)

    # if padding is a complicated expression, hoist it
    bounds_precomp: list[tuple[sympy.Symbol, Any]] = []
    for l, h in bounds:
        bounds_precomp.append((V.graph.sizevars.lookup_precomputed_size(l), h))  # type: ignore[arg-type]

    output_size = list(sizes[:n])
    mask_sizes = []
    for (low, high), size in zip(bounds, sizes[n:]):
        mask_sizes.append(size)
        output_size.append(sympy.expand(size + low + high))
    assert len(output_size) == len(sizes)
    fill_value = dtype_to_type(x.get_dtype())(fill_value)

    def mask(index):
        # `index` is already shifted into input coordinates by offset_fn.
        mask = []
        for idx, (low, high), length in zip(index[n:], bounds, mask_sizes):
            if low != 0:
                mask.append(range_mask_low(idx, 0))
            if high != 0:
                mask.append(range_mask_high(idx, length))
        mask = functools.reduce(ops.and_, mask)
        return ops.masked(mask, lambda: x_loader(index), fill_value)

    def offset_fn(index):
        new_index = list(index[:n])
        for idx, (low, _high) in zip(index[n:], bounds_precomp):
            new_index.append(idx - low)
        assert len(new_index) == len(index)
        return mask(new_index)

    x_loader = x.make_loader()
    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=offset_fn,
        ranges=output_size,
    )
def range_mask_low(i: sympy.Expr, low: Union[sympy.Expr, int]):
    """Return the op-level comparison ``i >= low`` on int64 index expressions."""
    index_value = ops.index_expr(i, torch.int64)
    lower_bound = ops.index_expr(sympy.Integer(low), torch.int64)
    return ops.ge(index_value, lower_bound)
def range_mask_high(i: sympy.Expr, high: sympy.Expr):
    """Return the op-level comparison ``i < high`` on int64 index expressions."""
    index_value = ops.index_expr(i, torch.int64)
    upper_bound = ops.index_expr(high, torch.int64)
    return ops.lt(index_value, upper_bound)
def range_mask(i: sympy.Expr, high: sympy.Expr, low: sympy.Expr):
    """Return the op-level mask for ``low <= i < high``."""
    above_low = range_mask_low(i, low)
    below_high = range_mask_high(i, high)
    return ops.and_(above_low, below_high)
def constant_boundary_condition(
    x, fill_value, padding=None, pad_fill_value=1.0, dim=None
):
    """Return a loader for ``x`` that yields ``fill_value`` out of range.

    The mask covers the last ``dim`` dims, widened on both sides by
    ``padding`` when given. With ``padding`` the in-range branch delegates to
    an unpadded boundary loader that uses ``pad_fill_value``; without it the
    in-range branch loads ``x`` directly.
    """
    h = x.get_size()[-dim:]
    x_loader = x.make_loader()
    padding_h = padding or [0] * dim

    def load(index):
        prefix = index[:-dim]
        ih = index[-dim:]

        per_dim_masks = [
            range_mask(ih[i], h[i] + padding_h[i], -padding_h[i]) for i in range(dim)
        ]
        mask = functools.reduce(ops.and_, per_dim_masks)

        if padding:
            return ops.masked(
                mask,
                lambda: constant_boundary_condition(x, pad_fill_value, dim=dim)(
                    [*prefix, *ih]
                ),
                fill_value,
            )
        return ops.masked(mask, lambda: x_loader([*prefix, *ih]), fill_value)

    return load
def pooling_size(x, i, kernel_size, stride, padding, ceil_mode, *, dilation=None):
    """Compute the pooled output extent for dimension ``i``.

    Args:
        x: input extent of the dimension.
        i: position used to pick the entry of each per-dim list.
        kernel_size, stride, padding: per-dim pooling parameters.
        ceil_mode: whether the ceil-rounded extent is requested.
        dilation: per-dim dilation; defaults to 1 for every dim.

    Returns:
        ``(x_out, ceil_mode)``; the returned ``ceil_mode`` is reset to
        ``False`` when the ceil and floor extents coincide according to the
        size hints (guarded with ``check_equals``).
    """
    if dilation is None:
        dilation = [1] * len(padding)

    step = stride[i]
    # Part of the numerator shared by the floor and ceil formulas.
    span = x + 2 * padding[i] - dilation[i] * (kernel_size[i] - 1)
    x_out = FloorDiv(span + (step - 1), step)

    if not ceil_mode:
        return x_out, ceil_mode

    x_alt = FloorDiv(span + 2 * (step - 1), step)
    if V.graph.sizevars.size_hint((x_alt - 1) * step - x - padding[i]) >= 0:
        # Sliding windows must start within the input or left padding
        x_alt = x_alt - 1
        V.graph.sizevars.check_leq(0, x_alt * step - x - padding[i])  # type: ignore[arg-type]

    if V.graph.sizevars.size_hint(x_out - x_alt) == 0:
        # ceil mode is actually a no-op, lets guard on that
        V.graph.sizevars.check_equals(x_out, x_alt)
        ceil_mode = False
    else:
        x_out = x_alt
    return x_out, ceil_mode
def should_fallback_max_pool_with_indices(kernel_size, *, n_dim):
    """Return True when the pooling window has more than 25 elements."""
    padded_kernel = pad_listlike(kernel_size, n_dim)
    window_size = functools.reduce(operator.mul, padded_kernel)
    return window_size > 25
def max_pool_checks(
    x, kernel_size, stride, padding, dilation, n_dim, *, assert_fallback=None
):
    """Normalize and validate max-pool arguments.

    Scalar ``padding == 0`` / ``dilation == 1`` are expanded to per-dim
    lists, a missing ``stride`` defaults to ``kernel_size``, and everything
    is passed through ``pad_listlike``.

    Returns:
        ``(kernel_size, stride, padding, dilation, use_fallback)``.

    Raises:
        AssertionError: on a length/rank mismatch, or when
            ``assert_fallback`` is given and disagrees with the computed
            fallback decision.
    """
    if padding == 0:
        padding = [0] * n_dim
    if dilation == 1:
        dilation = [1] * n_dim
    if not stride:
        stride = kernel_size

    kernel_size = pad_listlike(kernel_size, n_dim)
    stride = pad_listlike(stride, n_dim)
    padding = pad_listlike(padding, n_dim)
    dilation = pad_listlike(dilation, n_dim)

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == n_dim
    assert len(stride) == n_dim
    assert len(padding) == n_dim
    assert len(dilation) == n_dim
    # Input is either unbatched (n_dim + 1) or batched (n_dim + 2).
    assert len(x.get_size()) in (n_dim + 1, n_dim + 2)

    use_fallback = should_fallback_max_pool_with_indices(kernel_size, n_dim=n_dim)
    if assert_fallback is not None:
        assert use_fallback == assert_fallback

    return kernel_size, stride, padding, dilation, use_fallback
def _max_pool_with_offsets(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    *,
    n_dim,
):
    """Build the max and argmax reductions over each pooling window.

    Returns:
        ``(result, offsets)``: ``result`` holds the window maximum in the
        dtype of ``x``; ``offsets`` is an int64 argmax over the window
        (reduction ranges are ``kernel_size``). Each is realized only if it
        is still a ``Reduction`` node after creation.
    """
    x.realize_hint()
    batch = x.shape[:-n_dim]
    dhw = x.shape[-n_dim:]

    per_dim = [
        pooling_size(
            dhw[d], d, kernel_size, stride, padding, ceil_mode, dilation=dilation
        )
        for d in range(n_dim)
    ]
    dhw_out, ceil_mode = zip(*per_dim)

    dtype = x.dtype
    # Smallest representable value, used for out-of-range window positions.
    if dtype is torch.bool:
        min_value = False
    elif dtype.is_floating_point:
        min_value = float("-inf")
    else:
        min_value = torch.iinfo(dtype).min

    new_size = list(batch) + list(dhw_out)
    if any(padding) or any(ceil_mode) or any(d > 1 for d in dilation):
        x_loader = constant_boundary_condition(x, min_value, dim=n_dim)
    else:
        x_loader = x.make_loader()

    def fn_inner(idx, reduction_idx):
        prefix = idx[:-n_dim]
        bh = idx[-n_dim:]
        ih = [
            (bh[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
            for i in range(n_dim)
        ]
        return x_loader([*prefix, *ih])

    def make_reduction(kind, dst_dtype):
        return Reduction.create(
            reduction_type=kind,
            input_node=x,
            device=x.get_device(),
            dst_dtype=dst_dtype,
            src_dtype=dtype,
            inner_fn=fn_inner,
            ranges=new_size,
            reduction_ranges=kernel_size,
        )

    result = make_reduction("max", dtype)
    offsets = make_reduction("argmax", torch.int64)

    for node in (result, offsets):
        if isinstance(node.data.data, Reduction):  # type: ignore[attr-defined, union-attr]
            # Only realize if reduction isn't unrolled
            node.realize()

    return result, offsets
@register_lowering(prims._low_memory_max_pool_with_offsets, type_promotion_kind=None)
def _low_memory_max_pool_with_offsets(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode=False,
):
    """Lowering for the low-memory max pool prim.

    Returns the pooled values together with the in-window offsets cast to
    ``torch.int8``. The pooling rank is taken from ``len(kernel_size)``.
    """
    n_dim = len(kernel_size)

    # assert we are not on a fallback path, the inductor decomp should have guaranteed this
    kernel_size, stride, padding, dilation, _use_fallback = max_pool_checks(
        x,
        kernel_size,
        stride,
        padding,
        dilation,
        n_dim,
        assert_fallback=False,
    )

    with config.patch(unroll_reductions_threshold=25):
        pooled, window_offsets = _max_pool_with_offsets(
            x,
            kernel_size,
            stride,
            padding,
            dilation,
            ceil_mode,
            n_dim=n_dim,
        )
        return pooled, to_dtype(window_offsets, torch.int8)
def _pool_offsets_to_indices(
    offsets: TensorBox,
    kernel_size: Sequence[Union[int, torch.SymInt]],
    input_size: Sequence[Union[int, torch.SymInt]],
    increments_to_index: Callable[
        [Sequence[Union[int, torch.SymInt]], Sequence[Union[int, torch.SymInt]]],
        torch._inductor.virtualized.OpsValue,
    ],
) -> Union[TensorBox, ShapeAsConstantBuffer]:
    """Convert flat in-window offsets into flat indices over the pooled dims.

    Args:
        offsets: Per-output offset of the selected element inside its window.
        kernel_size: Window extents; their product bounds the offsets.
        input_size: Input shape; only the last ``len(kernel_size)`` entries
            are used when flattening.
        increments_to_index: Called as ``increments_to_index(idx,
            reduction_idx)``; returns the input coordinates for output
            position ``idx`` and window position ``reduction_idx``.

    Returns:
        A pointwise ``torch.int64`` node with the same size as ``offsets``.
    """
    n_dim = len(kernel_size)
    load_offset = offsets.make_loader()
    window_numel = sympy.sympify(functools.reduce(operator.mul, kernel_size))

    def offsets_to_indices(idx):
        raw_offset = load_offset(idx)
        offset_expr = ops.indirect_indexing(raw_offset, window_numel)
        # Un-flatten the offset into a position inside the window.
        window_pos = inductor_prims._flattened_index_to_nd(offset_expr, kernel_size)
        input_pos = increments_to_index(idx, window_pos)
        flat_index = inductor_prims._flatten_index(input_pos, input_size[-n_dim:])
        return ops.index_expr(flat_index, torch.int64)

    return Pointwise.create(
        device=offsets.get_device(),
        dtype=torch.int64,
        inner_fn=offsets_to_indices,
        ranges=offsets.get_size(),
    )
@register_lowering(
    prims._low_memory_max_pool_offsets_to_indices, type_promotion_kind=None
)
def _low_memory_max_pool_offsets_to_indices(
    offsets, kernel_size, input_size, stride, padding, dilation
):
    """Lowering that turns max-pool window offsets into flat input indices."""
    # TODO: Generalize to other max pooling flavors
    n_dim = len(kernel_size)

    def increments_to_index(idx, reduction_idx):
        # Same coordinate mapping as the forward pooling window.
        out_pos = idx[-n_dim:]
        return [
            (out_pos[i] * stride[i]) + (reduction_idx[i] * dilation[i]) - padding[i]
            for i in range(n_dim)
        ]

    indices = _pool_offsets_to_indices(
        offsets, kernel_size, input_size, increments_to_index
    )
    return indices
def _max_pool_with_indices(
    x,
    kernel_size,
    stride,
    padding,
    dilation,
    ceil_mode,
    n_dim,
):
    """Shared implementation of the ``max_poolNd_with_indices`` lowerings.

    Validates and normalizes the pooling parameters, computes pooled values
    plus in-window offsets, then converts the offsets into input indices.

    Returns:
        ``(out, indices)``.
    """
    kernel_size, stride, padding, dilation, _use_fallback = max_pool_checks(
        x, kernel_size, stride, padding, dilation, n_dim=n_dim
    )

    out, offsets = _max_pool_with_offsets(
        x, kernel_size, stride, padding, dilation, ceil_mode, n_dim=n_dim
    )

    pooled_dims = x.shape[-n_dim:]
    indices = _low_memory_max_pool_offsets_to_indices(
        offsets, kernel_size, pooled_dims, stride, padding, dilation
    )

    return out, indices
# Fallback when we do not decompose to the low-memory path.
@register_lowering(aten.max_pool2d_with_indices, type_promotion_kind=None)
def max_pool2d_with_indices(
    x,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """2D entry point; delegates to ``_max_pool_with_indices`` with ``n_dim=2``."""
    return _max_pool_with_indices(
        x=x,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
        n_dim=2,
    )
# Fallback when we do not decompose to the low-memory path.
@register_lowering(aten.max_pool3d_with_indices, type_promotion_kind=None)
def max_pool3d_with_indices(
    x,
    kernel_size,
    stride=None,
    padding=0,
    dilation=1,
    ceil_mode=False,
):
    """3D entry point; delegates to ``_max_pool_with_indices`` with ``n_dim=3``."""
    return _max_pool_with_indices(
        x=x,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        ceil_mode=ceil_mode,
        n_dim=3,
    )
fallback_max_pool2d_with_indices_backward = fallback_handler(
    aten.max_pool2d_with_indices_backward.default,
    add_to_fallback_set=False,
)


@register_lowering(aten.max_pool2d_with_indices_backward, type_promotion_kind=None)
def max_pool2d_with_indices_backward(
    grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
):
    """Lowering for the backward pass of 2D max pooling with indices.

    For every input element, the generated code visits each pooled output
    whose window may contain it and accumulates ``grad_output`` where the
    stored index equals the element's flat ``h * width + w`` position.

    The fallback handler is used when any dilation differs from 1 or when the
    number of candidate windows per element exceeds 25.
    """
    if padding == 0:
        padding = [0, 0]
    if dilation == 1:
        dilation = [1, 1]
    if not stride:
        stride = kernel_size

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(padding) == 2
    assert len(dilation) == 2
    assert len(x.get_size()) in (3, 4)

    def run_fallback():
        return fallback_max_pool2d_with_indices_backward(
            grad_output, x, kernel_size, stride, padding, dilation, ceil_mode, indices
        )

    # we will read this many times, so make sure it is computed
    grad_output.realize_hint()
    gO_stride = grad_output.maybe_get_stride()

    x_stride: Optional[Sequence[Any]]
    if isinstance(x, TensorBox) and isinstance(x.data.data, Pointwise):  # type: ignore[attr-defined]
        # x has no stride yet: wrap it in a throwaway buffer and let it decide
        # a layout so that a stride can be read back.
        pointwise = x.data.data  # type: ignore[attr-defined]
        device = pointwise.get_device()
        assert device is not None
        probe_layout = ir.FlexibleLayout(
            device=device,
            dtype=pointwise.get_dtype(),
            size=pointwise.get_size(),
        )
        probe_buffer = ir.ComputedBuffer(
            name=None,
            layout=probe_layout,
            data=pointwise,
        )
        probe_buffer.decide_layout()
        x_stride = probe_buffer.get_stride()
    else:
        x_stride = x.maybe_get_stride()

    is_channels_last = (x_stride is not None and x_stride[1] == 1) or (
        gO_stride is not None and gO_stride[1] == 1
    )

    if any(d != 1 for d in dilation):
        # dilation NYI
        return run_fallback()

    *_batch, _height, width = x.get_size()
    *_, pooled_height, pooled_width = grad_output.get_size()

    indices_loader = indices.make_loader()
    grad_loader = grad_output.make_loader()
    new_size = list(x.get_size())

    def windows_per_element(kernel, step):
        # Largest number of pooling windows along one axis that can cover a
        # single input position, with a minimum of 1.
        return max(
            max(pos // step - max(0, (pos - kernel) // step), 1)
            for pos in range(kernel * 2)
        )

    h_window_size = windows_per_element(kernel_size[0], stride[0])
    w_window_size = windows_per_element(kernel_size[1], stride[1])

    if h_window_size * w_window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return run_fallback()

    indices_size = indices.get_size()

    def fn(idx):
        *prefix, h, w = idx
        index_test = ops.index_expr(h * width + w, torch.int32)
        h = h + padding[0]
        w = w + padding[1]

        # Half-open range [start, end) of pooled positions whose window can
        # contain (h, w), before clamping to the pooled extent.
        phstart = ops.index_expr(
            FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
        )
        pwstart = ops.index_expr(
            FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
        )
        phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
        pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)

        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))

        gradient = None
        for ph_, pw_ in itertools.product(range(h_window_size), range(w_window_size)):
            ph = ops.add(phstart, ops.constant(ph_, torch.int32))
            pw = ops.add(pwstart, ops.constant(pw_, torch.int32))
            # Clamp to the last valid pooled position; bounds checks are off.
            grad_index = [
                *prefix,
                ops.indirect_indexing(
                    ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))),
                    indices_size[-2],
                    check=False,
                ),
                ops.indirect_indexing(
                    ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))),
                    indices_size[-1],
                    check=False,
                ),
            ]

            index_actual = indices_loader(grad_index)
            grad_part = grad_loader(grad_index)
            check = ops.eq(index_actual, index_test)

            if gradient is None:
                # don't need mask for 0, 0
                gradient = ops.where(
                    check, grad_part, ops.constant(0.0, torch.float32)
                )
                continue

            in_range = ops.and_(
                ops.lt(ph, phend),
                ops.lt(pw, pwend),
            )
            mask = ops.and_(in_range, check)
            gradient = ops.where(mask, ops.add(gradient, grad_part), gradient)

        assert gradient is not None
        return gradient

    out = Pointwise.create(
        device=grad_output.get_device(),
        dtype=grad_output.get_dtype(),
        inner_fn=fn,
        ranges=new_size,
    )
    if not is_channels_last:
        return out
    return ir.ExternKernel.require_channels_last(out)
def pad_adaptive_loader(x, pad_val=0.0):
    """Build a 2D window loader for ``x`` that yields ``pad_val`` past the window end.

    The returned ``load(prefix, increments, start_indices, end_indices)`` reads
    ``x[*prefix, h_start + ih, w_start + iw]`` when both coordinates are below
    their end index, and ``pad_val`` otherwise.
    """
    x_loader = x.make_loader()

    def load(prefix, increments, start_indices, end_indices):
        ih, iw = increments
        h_start_index, w_start_index = start_indices
        h_end_index, w_end_index = end_indices

        h_pos = h_start_index + ih
        w_pos = w_start_index + iw

        h_inside = ops.lt(
            ops.index_expr(h_pos, torch.int64),
            ops.index_expr(h_end_index, torch.int64),
        )
        w_inside = ops.lt(
            ops.index_expr(w_pos, torch.int64),
            ops.index_expr(w_end_index, torch.int64),
        )
        mask = ops.and_(h_inside, w_inside)

        def read_input():
            return x_loader([*prefix, h_pos, w_pos])

        return ops.masked(mask, read_input, pad_val)

    return load
def compute_indices_adaptive_pooling(start_index, end_index, h_in, w_in, h_out, w_out):
    """Bind ``start_index`` / ``end_index`` to the height and width sizes.

    Returns:
        ``(h_start, h_end, w_start, w_end)``: ``functools.partial`` objects that
        carry ``out_dim`` / ``inp_dim`` as keywords and still need the index.
    """
    height_dims = {"out_dim": h_out, "inp_dim": h_in}
    width_dims = {"out_dim": w_out, "inp_dim": w_in}
    return (
        functools.partial(start_index, **height_dims),
        functools.partial(end_index, **height_dims),
        functools.partial(start_index, **width_dims),
        functools.partial(end_index, **width_dims),
    )
def _adaptive_pooling_fn(
    start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
    """Build a function that reduces a 2D adaptive-pooling window.

    The returned ``fn(idx, loader)`` calls ``loader(prefix, [ih, iw],
    [h_start, w_start], [h_end, w_end])`` for every ``(ih, iw)`` in
    ``kernel_maxes[0] x kernel_maxes[1]`` and folds the values with
    ``pooling_fn(val, accumulated)``.
    """
    h_in, w_in = in_sizes
    h_out, w_out = out_sizes

    bound_fns = compute_indices_adaptive_pooling(
        start_index, end_index, h_in, w_in, h_out, w_out
    )
    h_start_index_fn, h_end_index_fn, w_start_index_fn, w_end_index_fn = bound_fns

    def fn(idx, loader):
        *prefix, bh, bw = idx

        h_start_index = h_start_index_fn(bh)
        h_end_index = h_end_index_fn(bh)

        w_start_index = w_start_index_fn(bw)
        w_end_index = w_end_index_fn(bw)

        h_steps = range(kernel_maxes[0])
        w_steps = range(kernel_maxes[1])

        result = None
        for ih in h_steps:
            for iw in w_steps:
                val = loader(
                    prefix,
                    [ih, iw],
                    [h_start_index, w_start_index],
                    [h_end_index, w_end_index],
                )
                # The first value seeds the accumulator.
                result = val if result is None else pooling_fn(val, result)
        return result

    return fn
def _adaptive_pooling_fn_with_idx(
    start_index, end_index, kernel_maxes, in_sizes, out_sizes, pooling_fn
):
    """Build a function returning the flat index of the selected window element.

    The returned ``fn(idx, loader)`` walks the same window as
    ``_adaptive_pooling_fn`` while tracking the running value (folded with
    ``pooling_fn``) and the flat index ``(h_start + ih) * w_in + w_start + iw``
    of the element that is strictly greater than the running value.
    """
    h_in, w_in = in_sizes
    h_out, w_out = out_sizes

    bound_fns = compute_indices_adaptive_pooling(
        start_index, end_index, h_in, w_in, h_out, w_out
    )
    h_start_index_fn, h_end_index_fn, w_start_index_fn, w_end_index_fn = bound_fns

    def fn(idx, loader):
        *prefix, bh, bw = idx

        h_start_index = h_start_index_fn(bh)
        h_end_index = h_end_index_fn(bh)

        w_start_index = w_start_index_fn(bw)
        w_end_index = w_end_index_fn(bw)

        h_steps = range(kernel_maxes[0])
        w_steps = range(kernel_maxes[1])

        maxval = None
        maxindex = None
        for ih in h_steps:
            for iw in w_steps:
                val = loader(
                    prefix,
                    [ih, iw],
                    [h_start_index, w_start_index],
                    [h_end_index, w_end_index],
                )
                flat_pos = (h_start_index + ih) * w_in + w_start_index + iw
                index = ops.index_expr(flat_pos, torch.int64)

                # Update the index first: it compares against the previous max.
                if maxindex is None:
                    maxindex = index
                else:
                    maxindex = ops.where(ops.gt(val, maxval), index, maxindex)

                maxval = val if maxval is None else pooling_fn(val, maxval)
        return maxindex

    return fn
fallback_adaptive_avg_pool2d = fallback_handler(
    aten._adaptive_avg_pool2d.default, add_to_fallback_set=False
)


@register_lowering(aten._adaptive_avg_pool2d)
def _adaptive_avg_pool2d(x, output_size):
    """Lowering for 2D adaptive average pooling.

    Fast paths: identical input/output size (clone), an empty output, and
    evenly divisible sizes (plain ``avg_pool2d``). Otherwise each output is the
    masked window sum divided by the masked count of window elements; the
    fallback handler is used when the maximal window exceeds 25 elements.

    Raises:
        RuntimeError: If ``x`` has dtype ``torch.int64``.
    """
    if x.get_dtype() == torch.int64:
        # not supported in eager
        raise RuntimeError("'adaptive_avg_pool2d' not implemented for 'Long'")
    assert isinstance(x, TensorBox)
    assert len(output_size) == 2
    x.realize_hint()

    *batch, h_in, w_in = x.get_size()
    sizevars = V.graph.sizevars
    h_in = sizevars.guard_int(h_in)
    w_in = sizevars.guard_int(w_in)

    h_out, w_out = output_size

    # no-op if the same input and output
    if h_in == h_out and w_in == w_out:
        return clone(x)

    if h_out == 0 or w_out == 0:
        empty_size = [*batch, h_out, w_out]
        return empty(empty_size, dtype=x.get_dtype(), device=x.get_device())

    if h_in % h_out == 0 and w_in % w_out == 0:
        return avg_pool2d(x, [h_in // h_out, w_in // w_out])

    h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
    w_kernel_max = ceildiv((w_in + w_out - 1), w_out)

    new_size = [*batch, h_out, w_out]
    dtype = x.get_dtype()

    if h_kernel_max * w_kernel_max > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_adaptive_avg_pool2d(x, output_size)

    def start_index(index, out_dim, inp_dim):
        return FloorDiv((index * inp_dim), out_dim)

    def end_index(index, out_dim, inp_dim):
        return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)

    fn_sum = _adaptive_pooling_fn(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[h_in, w_in],
        out_sizes=[h_out, w_out],
        pooling_fn=ops.add,
    )

    # Summing a tensor of ones over the same masked window gives the count.
    ones_loader = pad_adaptive_loader(ones_like(x))

    def fn(idx):
        window_sum = fn_sum(idx, pad_adaptive_loader(x))
        window_count = fn_sum(idx, ones_loader)
        return ops.truediv(window_sum, window_count)

    # TODO: should we force these to be realized?
    return Pointwise.create(
        device=x.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
fallback_adaptive_max_pool2d = fallback_handler(
    aten.adaptive_max_pool2d.default, add_to_fallback_set=False
)


@register_lowering(aten.adaptive_max_pool2d)
def adaptive_max_pool2d(x, output_size):
    """Lowering for 2D adaptive max pooling.

    Args:
        x: Input whose last two dimensions are pooled.
        output_size: Target ``(h_out, w_out)``.

    Returns:
        ``(values, indices)``; ``indices`` has dtype ``torch.int64`` and holds
        ``h * w_in + w`` of the selected element. An empty output size yields a
        pair of empty tensors, and windows larger than 25 elements are routed
        to the fallback handler.

    Raises:
        RuntimeError: If ``x`` has dtype ``torch.int64``.
        ValueError: If the input size is evenly divisible by the output size;
            that case is expected to be handled by a decomposition before this
            lowering is reached.
    """
    if x.get_dtype() == torch.int64:
        # not supported in eager
        raise RuntimeError("adaptive_max_pool2d not implemented for Long")
    assert isinstance(x, TensorBox)
    assert len(output_size) == 2
    x.realize_hint()

    *batch, h_in, w_in = x.get_size()

    h_in = V.graph.sizevars.guard_int(h_in)
    w_in = V.graph.sizevars.guard_int(w_in)

    h_out, w_out = output_size

    if h_out == 0 or w_out == 0:
        o_size = [*batch, h_out, w_out]
        return empty(o_size, dtype=x.get_dtype(), device=x.get_device()), empty(
            o_size, dtype=torch.int64, device=x.get_device()
        )

    if h_in % h_out == 0 and w_in % w_out == 0:
        # This is handled by a decomposition
        raise ValueError(
            "adaptive_max_pool2d lowering reached with an input size "
            f"({h_in}, {w_in}) evenly divisible by the output size "
            f"({h_out}, {w_out}); this case should have been handled by a "
            "decomposition"
        )

    h_kernel_max = ceildiv((h_in + h_out - 1), h_out)
    w_kernel_max = ceildiv((w_in + w_out - 1), w_out)

    new_size = list(batch) + [h_out, w_out]
    dtype = x.get_dtype()

    window_size = h_kernel_max * w_kernel_max
    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_adaptive_max_pool2d(x, output_size)

    def start_index(index, out_dim, inp_dim):
        return FloorDiv((index * inp_dim), out_dim)

    def end_index(index, out_dim, inp_dim):
        return FloorDiv((index + 1) * inp_dim + out_dim - 1, out_dim)

    inner_func_max_val = _adaptive_pooling_fn(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[h_in, w_in],
        out_sizes=[h_out, w_out],
        pooling_fn=ops.maximum,
    )

    inner_func_max_idx = _adaptive_pooling_fn_with_idx(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[h_in, w_in],
        out_sizes=[h_out, w_out],
        pooling_fn=ops.maximum,
    )

    # Positions past the end of a window read as -inf so they never win the max.
    def inner_fn_max_val(idx):
        return inner_func_max_val(idx, pad_adaptive_loader(x, float("-inf")))

    def inner_fn_max_idx(idx):
        return inner_func_max_idx(idx, pad_adaptive_loader(x, float("-inf")))

    rv = Pointwise.create(
        device=x.get_device(),
        dtype=dtype,
        inner_fn=inner_fn_max_val,
        ranges=new_size,
    )
    ri = Pointwise.create(
        device=x.get_device(),
        dtype=torch.int64,
        inner_fn=inner_fn_max_idx,
        ranges=new_size,
    )
    return rv, ri
def _fractional_pooling_offsets(samples, in_sz, out_sz, kernel_sz, dim, ndims):
    """Build a loader for the window start index along one pooled dimension.

    Args:
        samples: Random samples tensor; read at last-axis position
            ``ndims - 1 - dim``.
        in_sz, out_sz, kernel_sz: Per-dimension sizes; only entry ``dim`` is used.
        dim: Pooled dimension this loader serves.
        ndims: Number of pooled dimensions.

    Returns:
        ``load(prefix, i)`` giving the start index (as an indirect index bounded
        by the input size) of output position ``i``.
    """
    out_sz = out_sz[dim]
    in_sz = in_sz[dim]
    kernel_sz = kernel_sz[dim]
    samples_loader = samples.make_loader()

    def load(prefix, i):
        # Handle indexing for samples tensor correctly for different input dimensions
        # samples tensor always has shape (N, C, 2) for fractional_max_pool2d where:
        # - N=1 for 3D inputs (C,H,W), N=batch_size for 4D inputs (N,C,H,W)
        # - C=num_channels
        # - 2 for the two spatial dimensions (height, width)
        sample_col = ndims - 1 - dim
        if len(samples.get_size()) != 3:
            # Fallback for unexpected tensor shapes
            sample_idx = [*prefix, sample_col]
        elif len(prefix) == 1:
            # 3D input case: prefix=(channel,), samples=(1, C, 2)
            sample_idx = [0, prefix[0], sample_col]
        elif len(prefix) >= 2:
            # 4D+ input case: prefix=(batch, channel, ...), samples=(batch, C, 2)
            sample_idx = [prefix[0], prefix[1], sample_col]
        else:
            # Edge case - shouldn't happen for valid fractional pooling
            sample_idx = [0, 0, sample_col]
        sample = samples_loader(sample_idx)

        i_expr = ops.index_expr(i, samples.get_dtype())
        diff = ops.index_expr(in_sz - kernel_sz, torch.int64)
        out_sz_expr = ops.index_expr(out_sz - 1, torch.int64)

        alpha = ops.truediv(
            ops.to_dtype(diff, torch.float64), ops.to_dtype(out_sz_expr, torch.float64)
        )
        # Avoid a division by zero result when there is a single output position.
        alpha = ops.where(ops.eq(out_sz_expr, 0), 0, alpha)

        seq_i = ops.trunc((i_expr + sample) * alpha) - ops.trunc(sample * alpha)
        seq_i = ops.to_dtype(seq_i, torch.int64)

        # The last output position always starts at in_sz - kernel_sz.
        mask = ops.lt(i_expr, out_sz_expr)
        start = ops.where(mask, seq_i, diff)
        return ops.indirect_indexing(start, sympy.sympify(in_sz))

    return load
@register_lowering(aten.fractional_max_pool2d)
def fractional_max_pool2d(x, kernel_size, output_size, random_samples):
    """2D entry point; delegates to ``_fractional_max_pool`` with ``n_dim=2``."""
    return _fractional_max_pool(
        x=x,
        kernel_size=kernel_size,
        output_size=output_size,
        random_samples=random_samples,
        n_dim=2,
    )
@register_lowering(aten.fractional_max_pool3d)
def fractional_max_pool3d(x, kernel_size, output_size, random_samples):
    """3D entry point; delegates to ``_fractional_max_pool`` with ``n_dim=3``."""
    return _fractional_max_pool(
        x=x,
        kernel_size=kernel_size,
        output_size=output_size,
        random_samples=random_samples,
        n_dim=3,
    )
def _fractional_max_pool(x, kernel_size, output_size, random_samples, n_dim):
    """Shared lowering for fractional max pooling over ``n_dim`` trailing dims.

    Window starts come from ``_fractional_pooling_offsets`` (one loader per
    pooled dimension). A ``max`` and an ``argmax`` reduction run over each
    window, and the argmax offsets are converted to input indices.

    Returns:
        ``(result, indices)``.
    """
    x.realize_hint()
    batch, inp_dhw = x.shape[:-n_dim], x.shape[-n_dim:]

    with config.patch(unroll_reductions_threshold=25):
        dhw_index_fn = [
            _fractional_pooling_offsets(
                samples=random_samples,
                in_sz=inp_dhw,
                out_sz=output_size,
                kernel_sz=kernel_size,
                ndims=n_dim,
                dim=d,
            )
            for d in range(n_dim)
        ]

        x_loader = x.make_loader()

        def increments_to_index(idx, reduction_idx):
            # Window start for each pooled dim plus the position in the window.
            prefix = idx[:-n_dim]
            out_pos = idx[-n_dim:]
            return [
                dhw_index_fn[d](prefix, out_pos[d]) + reduction_idx[d]
                for d in range(n_dim)
            ]

        def fn_inner(idx, reduction_idx):
            prefix = idx[:-n_dim]
            return x_loader([*prefix, *increments_to_index(idx, reduction_idx)])

        new_size = [*batch, *output_size]
        dtype = x.get_dtype()

        def make_reduction(reduction_type, dst_dtype):
            return Reduction.create(
                reduction_type=reduction_type,
                input_node=x,
                device=x.get_device(),
                dst_dtype=dst_dtype,
                src_dtype=dtype,
                inner_fn=fn_inner,
                ranges=new_size,
                reduction_ranges=kernel_size,
            )

        result = make_reduction("max", dtype)
        offsets = make_reduction("argmax", torch.int64)

        # Only realize if the reduction wasn't unrolled.
        assert isinstance(result, TensorBox), result
        if isinstance(result.data.data, Reduction):  # type: ignore[attr-defined]
            result.realize()
        assert isinstance(offsets, TensorBox), offsets
        if isinstance(offsets.data.data, Reduction):  # type: ignore[attr-defined]
            offsets.realize()

        indices = _pool_offsets_to_indices(
            offsets, kernel_size, x.shape, increments_to_index
        )
        return result, indices
@register_lowering(aten.upsample_nearest2d_backward.default)
def upsample_nearest2d_backward(
    x, output_size=None, input_size=None, scales_h=None, scales_w=None
):
    """Lowering for the backward pass of nearest-neighbour 2D upsampling.

    ``x`` is the incoming gradient and ``input_size`` is the shape of the
    result. When the gradient size is an exact multiple of the result size the
    work is a sum pool (``avg_pool2d`` with ``divisor_override=1``); otherwise
    each result element sums a masked window of the gradient.
    """
    x.realize_hint()

    *_batch, inp_h, inp_w = x.get_size()
    sizevars = V.graph.sizevars
    inp_h = sizevars.guard_int(inp_h)
    inp_w = sizevars.guard_int(inp_w)

    *_batch, out_h, out_w = input_size

    if inp_h % out_h == 0 and inp_w % out_w == 0:
        pool_kernel = [inp_h // out_h, inp_w // out_w]
        return avg_pool2d(x, pool_kernel, divisor_override=1)

    h_kernel_max = ceildiv(inp_h, out_h)
    w_kernel_max = ceildiv(inp_w, out_w)

    def start_index(index, out_dim, inp_dim):
        return CeilDiv(index * inp_dim, sympy.sympify(out_dim))

    def end_index(index, out_dim, inp_dim):
        # A window ends where the next one starts.
        return start_index((index + 1), out_dim, inp_dim)

    fn_sum = _adaptive_pooling_fn(
        start_index=start_index,
        end_index=end_index,
        kernel_maxes=[h_kernel_max, w_kernel_max],
        in_sizes=[inp_h, inp_w],
        out_sizes=[out_h, out_w],
        pooling_fn=ops.add,
    )

    def fn(idx):
        return fn_sum(idx, pad_adaptive_loader(x))

    return Pointwise.create(
        device=x.get_device(),
        dtype=x.get_dtype(),
        inner_fn=fn,
        ranges=list(input_size),
    )
# Fallback handlers used by the average-pooling lowering when the pooling
# window is too large to unroll.
fallback_avg_pool2d = fallback_handler(
    aten.avg_pool2d.default,
    add_to_fallback_set=False,
)
fallback_avg_pool3d = fallback_handler(
    aten.avg_pool3d.default,
    add_to_fallback_set=False,
)
@register_lowering(aten.avg_pool2d, type_promotion_kind=None)
def avg_pool2d(
    x,
    kernel_size,
    stride=(),
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """2D entry point; delegates to ``_avg_poolnd`` with ``dim=2``."""
    return _avg_poolnd(
        x=x,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        ceil_mode=ceil_mode,
        count_include_pad=count_include_pad,
        divisor_override=divisor_override,
        dim=2,
    )
@register_lowering(aten.avg_pool3d, type_promotion_kind=None)
def avg_pool3d(
    x,
    kernel_size,
    stride=(),
    padding=0,
    ceil_mode=False,
    count_include_pad=True,
    divisor_override=None,
):
    """3D entry point; delegates to ``_avg_poolnd`` with ``dim=3``."""
    return _avg_poolnd(
        x=x,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        ceil_mode=ceil_mode,
        count_include_pad=count_include_pad,
        divisor_override=divisor_override,
        dim=3,
    )
def _avg_poolnd(
    x,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override,
    dim,
):
    """Shared lowering for average pooling over the last ``dim`` dimensions.

    The window sum is unrolled into pointwise code. The divisor is
    ``divisor_override`` when given, the full window size when no padding or
    ceil-mode overhang is involved, and otherwise a per-position window extent
    (clipped to the input when ``count_include_pad`` is false). Windows with
    more than 25 elements use the 2D/3D fallback handler.

    Raises:
        ValueError: If the fallback is needed and ``dim`` is neither 2 nor 3.
    """
    if not stride:
        stride = kernel_size
    if not padding:
        padding = [0] * dim
    kernel_size = pad_listlike(kernel_size, dim)
    stride = pad_listlike(stride, dim)
    padding = pad_listlike(padding, dim)

    assert isinstance(x, TensorBox)
    assert len(kernel_size) == dim
    assert len(stride) == dim
    assert len(padding) == dim
    assert len(x.get_size()) in (dim + 1, dim + 2)

    x.realize_hint()
    batch = x.get_size()[:-dim]
    h = x.get_size()[-dim:]

    per_dim = [
        pooling_size(h[i], i, kernel_size, stride, padding, ceil_mode)
        for i in range(dim)
    ]
    h_out, ceil_modes = zip(*per_dim)

    had_padding = bool(any(padding) or any(ceil_modes))
    if had_padding:
        # Out-of-bounds reads contribute 0.0 to the sum.
        x_loader = constant_boundary_condition(x, 0.0, dim=dim)
    else:
        x_loader = x.make_loader()

    new_size = [*batch, *h_out]
    dtype = x.get_dtype()

    window_size = functools.reduce(operator.mul, kernel_size)
    if window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        if dim == 2:
            fallback = fallback_avg_pool2d
        elif dim == 3:
            fallback = fallback_avg_pool3d
        else:
            raise ValueError(f"Unknown dim: {dim}")

        return fallback(
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    def fn_sum(idx, loader):
        prefix = idx[:-dim]
        out_pos = idx[-dim:]
        window_ranges = [range(kernel_size[i]) for i in range(dim)]
        total = None
        for offset in itertools.product(*window_ranges):
            in_pos = [
                out_pos[i] * stride[i] + offset[i] - padding[i] for i in range(dim)
            ]
            val = loader([*prefix, *in_pos])
            total = val if total is None else ops.add(val, total)
        return total

    if not had_padding or divisor_override:
        # Constant divisor for every output position.
        divisor = divisor_override if divisor_override else window_size
        if dtype.is_floating_point:
            scale = 1 / divisor

            def fn(idx):
                return ops.mul(fn_sum(idx, x_loader), ops.constant(scale, dtype))

        else:

            def fn(idx):
                # C style integer division as done in native/cpu/AvgPoolKernel.cpp
                return ops.truncdiv(fn_sum(idx, x_loader), ops.constant(divisor, dtype))

    else:

        def fn(idx):
            out_pos = idx[-dim:]

            divide_factors = []
            for i in range(dim):
                hstart = out_pos[i] * stride[i] - padding[i]
                hend = sympy.Min(hstart + kernel_size[i], h[i] + padding[i])
                if not count_include_pad:
                    # Clip the window to the unpadded input.
                    hstart = sympy.Max(hstart, 0)
                    hend = sympy.Min(hend, h[i])
                divide_factors.append(ops.index_expr(hend - hstart, torch.int32))
            divide_factor = functools.reduce(ops.mul, divide_factors)

            if dtype.is_floating_point:
                return ops.truediv(fn_sum(idx, x_loader), divide_factor)
            # C style integer division as done in native/cpu/AvgPoolKernel.cpp
            return ops.truncdiv(fn_sum(idx, x_loader), divide_factor)

    # TODO(jansel): should we force these to be realized?
    return Pointwise.create(
        device=x.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
fallback_avg_pool2d_backward = fallback_handler(
    aten.avg_pool2d_backward.default, add_to_fallback_set=False
)


@register_lowering(aten.avg_pool2d_backward, type_promotion_kind=None)
def avg_pool2d_backward(
    grad_output,
    x,
    kernel_size,
    stride,
    padding,
    ceil_mode,
    count_include_pad,
    divisor_override=None,
):
    """Lowering for the backward pass of 2D average pooling.

    For every input element, the generated code visits each pooled output
    whose window covers it and accumulates ``grad_output / scale``. ``scale``
    is ``divisor_override`` when given, the full kernel area when padding is
    counted or absent, and otherwise the window area clipped to the input.
    The fallback handler is used when more than 25 windows may cover one
    element.
    """
    assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
    if not stride:
        stride = kernel_size
    if not padding:
        padding = [0, 0]

    assert isinstance(grad_output, TensorBox)
    assert isinstance(x, TensorBox)
    assert len(kernel_size) == 2
    assert len(stride) == 2
    assert len(padding) == 2
    assert len(x.get_size()) in (3, 4)

    grad_output.realize_hint()  # we will read this many times, so make sure it is computed

    *_, height, width = x.get_size()

    _h_out, ceil_mode1 = pooling_size(
        height, 0, kernel_size, stride, padding, ceil_mode
    )
    _w_out, ceil_mode2 = pooling_size(
        width, 1, kernel_size, stride, padding, ceil_mode
    )

    grad_loader = grad_output.make_loader()

    had_padding = padding[0] or padding[1] or ceil_mode1 or ceil_mode2

    *_, pooled_height, pooled_width = grad_output.get_size()
    new_size = list(x.get_size())
    dtype = x.get_dtype()

    def windows_per_element(kernel, step):
        # Largest number of pooling windows along one axis that can cover a
        # single input position, with a minimum of 1.
        return max(
            max(pos // step - max(0, (pos - kernel) // step), 1)
            for pos in range(kernel * 2)
        )

    h_window_size = windows_per_element(kernel_size[0], stride[0])
    w_window_size = windows_per_element(kernel_size[1], stride[1])

    if h_window_size * w_window_size > 25:
        # Kernel size too big. Results in hard-to-optimize Triton code. Use fallback.
        return fallback_avg_pool2d_backward(
            grad_output,
            x,
            kernel_size,
            stride,
            padding,
            ceil_mode,
            count_include_pad,
            divisor_override,
        )

    def compute_pool_size_without_padding(ph, pw):
        """
        This computes the scaling factor that we will divide an element
        by when `count_include_pad=False`
        """
        stride_h = ops.constant(stride[0], torch.int32)
        stride_w = ops.constant(stride[1], torch.int32)
        pad_h = ops.constant(padding[0], torch.int32)
        pad_w = ops.constant(padding[1], torch.int32)
        kernel_h = ops.constant(kernel_size[0], torch.int32)
        kernel_w = ops.constant(kernel_size[1], torch.int32)

        # Window bounds in input coordinates, limited to the padded input.
        hstart = ops.sub(ops.mul(ph, stride_h), pad_h)
        wstart = ops.sub(ops.mul(pw, stride_w), pad_w)
        hend = ops.minimum(
            ops.add(hstart, kernel_h),
            ops.add(ops.index_expr(height, torch.int32), pad_h),
        )
        wend = ops.minimum(
            ops.add(wstart, kernel_w),
            ops.add(ops.index_expr(width, torch.int32), pad_w),
        )

        # Clip to the unpadded input.
        hstart = ops.maximum(hstart, ops.constant(0, torch.int32))
        wstart = ops.maximum(wstart, ops.constant(0, torch.int32))
        hend = ops.minimum(hend, ops.index_expr(height, torch.int32))
        wend = ops.minimum(wend, ops.index_expr(width, torch.int32))

        return ops.mul(ops.sub(hend, hstart), ops.sub(wend, wstart))

    def fn(idx):
        *prefix, h, w = idx
        h = h + padding[0]
        w = w + padding[1]

        # Half-open range [start, end) of pooled positions whose window can
        # contain (h, w), before clamping to the pooled extent.
        phstart = ops.index_expr(
            FloorDiv(h - kernel_size[0] + stride[0], stride[0]), torch.int32
        )
        pwstart = ops.index_expr(
            FloorDiv(w - kernel_size[1] + stride[1], stride[1]), torch.int32
        )
        phend = ops.index_expr(FloorDiv(h, stride[0]) + 1, torch.int32)
        pwend = ops.index_expr(FloorDiv(w, stride[1]) + 1, torch.int32)

        phstart = ops.maximum(phstart, ops.constant(0, torch.int32))
        pwstart = ops.maximum(pwstart, ops.constant(0, torch.int32))
        phend = ops.minimum(phend, ops.index_expr(pooled_height, torch.int32))
        pwend = ops.minimum(pwend, ops.index_expr(pooled_width, torch.int32))

        gradient = None
        for ph_, pw_ in itertools.product(range(h_window_size), range(w_window_size)):
            ph = ops.add(phstart, ops.constant(ph_, torch.int32))
            pw = ops.add(pwstart, ops.constant(pw_, torch.int32))

            if divisor_override is not None:
                scale = divisor_override
            elif count_include_pad or not had_padding:
                scale = kernel_size[0] * kernel_size[1]
            else:
                scale = compute_pool_size_without_padding(ph, pw)

            # Clamp to the last valid pooled position; bounds checks are off.
            grad_index = [
                *prefix,
                ops.indirect_indexing(
                    ops.minimum(ph, ops.sub(phend, ops.constant(1, torch.int32))),
                    pooled_height,
                    check=False,
                ),
                ops.indirect_indexing(
                    ops.minimum(pw, ops.sub(pwend, ops.constant(1, torch.int32))),
                    pooled_width,
                    check=False,
                ),
            ]
            part = ops.truediv(grad_loader(grad_index), scale)

            mask = ops.and_(
                ops.lt(ph, phend),
                ops.lt(pw, pwend),
            )
            if gradient is None:
                gradient = ops.where(mask, part, ops.constant(0.0, torch.float32))
            else:
                gradient = ops.where(mask, ops.add(gradient, part), gradient)

        assert gradient is not None
        return gradient

    return Pointwise.create(
        device=grad_output.get_device(),
        dtype=dtype,
        inner_fn=fn,
        ranges=new_size,
    )
# Fallback handler for the 3D average-pooling backward lowering.
fallback_avg_pool3d_backward = fallback_handler(
    aten.avg_pool3d_backward.default,
    add_to_fallback_set=False,
)
- @register_lowering(aten.avg_pool3d_backward, type_promotion_kind=None)
- def avg_pool3d_backward(
- grad_output,
- x,
- kernel_size,
- stride,
- padding,
- ceil_mode,
- count_include_pad,
- divisor_override=None,
- ):
- assert divisor_override is None or divisor_override != 0, "divisor must be not zero"
- if not stride:
- stride = kernel_size
- if not padding:
- padding = [0, 0, 0]
- assert isinstance(grad_output, TensorBox)
- assert isinstance(x, TensorBox)
- assert len(kernel_size) == 3
- assert len(stride) == 3
- assert len(padding) == 3
- assert len(x.get_size()) in (4, 5)
- grad_output.realize_hint()
- *_batch, depth, height, width = x.get_size()
- _d_out, ceil_mode_d = pooling_size(
- depth, 0, kernel_size, stride, padding, ceil_mode
- )
- _h_out, ceil_mode_h = pooling_size(
- height, 1, kernel_size, stride, padding, ceil_mode
- )
- _w_out, ceil_mode_w = pooling_size(
- width, 2, kernel_size, stride, padding, ceil_mode
- )
- grad_loader = grad_output.make_loader()
- had_padding = any(padding) or ceil_mode_d or ceil_mode_h or ceil_mode_w
- *_, pooled_depth, pooled_height, pooled_width = grad_output.get_size()
- new_size = list(x.get_size())
- dtype = x.get_dtype()
- d_window_size, h_window_size, w_window_size = (
- max(
- max(d // stride[i] - max(0, (d - kernel_size[i]) // stride[i]), 1)
- for d in range(kernel_size[i] * 2)
- )
- for i in range(3)
- )
- window_size = d_window_size * h_window_size * w_window_size
- if window_size > 125:
- # Kernel size too big. Results in hard-to-optimize Triton code.
- return fallback_avg_pool3d_backward(
- grad_output,
- x,
- kernel_size,
- stride,
- padding,
- ceil_mode,
- count_include_pad,
- divisor_override,
- )
- def compute_pool_size_without_padding(pd, ph, pw):
- stride_d, stride_h, stride_w = (ops.constant(s, torch.int32) for s in stride)
- pad_d, pad_h, pad_w = (ops.constant(p, torch.int32) for p in padding)
- kernel_d, kernel_h, kernel_w = (
- ops.constant(k, torch.int32) for k in kernel_size
- )
- dstart, hstart, wstart = (
- ops.sub(ops.mul(p, s), pad)
- for p, s, pad in zip(
- [pd, ph, pw], [stride_d, stride_h, stride_w], [pad_d, pad_h, pad_w]
- )
- )
- dend, hend, wend = (
- ops.minimum(
- ops.add(start, k), ops.add(ops.index_expr(dim, torch.int32), pad)
- )
- for start, k, dim, pad in zip(
- [dstart, hstart, wstart],
- [kernel_d, kernel_h, kernel_w],
- [depth, height, width],
- [pad_d, pad_h, pad_w],
- )
- )
- dstart, hstart, wstart = (
- ops.maximum(start, ops.constant(0, torch.int32))
- for start in [dstart, hstart, wstart]
- )
- dend, hend, wend = (
- ops.minimum(end, ops.index_expr(dim, torch.int32))
- for end, dim in zip([dend, hend, wend], [depth, height, width])
- )
- divide_factor = ops.mul(
- ops.mul(ops.sub(dend, dstart), ops.sub(hend, hstart)), ops.sub(wend, wstart)
- )
- return divide_factor
- def fn(idx):
- *prefix, d, h, w = idx
- d, h, w = (v + pad for v, pad in zip([d, h, w], padding))
- pdstart, phstart, pwstart = (
- ops.index_expr(FloorDiv(v - k + s, s), torch.int32)
- for v, k, s in zip([d, h, w], kernel_size, stride)
- )
- pdend, phend, pwend = (
- ops.index_expr(FloorDiv(v, s) + 1, torch.int32)
- for v, s in zip([d, h, w], stride)
- )
- pdstart, phstart, pwstart = (
- ops.maximum(pstart, ops.constant(0, torch.int32))
- for pstart in [pdstart, phstart, pwstart]
- )
- pdend, phend, pwend = (
- ops.minimum(pend, ops.index_expr(pooled_dim, torch.int32))
- for pend, pooled_dim in zip(
- [pdend, phend, pwend], [pooled_depth, pooled_height, pooled_width]
- )
- )
- gradient = None
- # Iterate over the 3D region to accumulate gradients
- for pd_ in range(d_window_size):
- for ph_ in range(h_window_size):
- for pw_ in range(w_window_size):
- pd, ph, pw = (
- ops.add(pstart, ops.constant(p_, torch.int32))
- for pstart, p_ in zip(
- [pdstart, phstart, pwstart], [pd_, ph_, pw_]
- )
- )
- if divisor_override is not None:
- scale = divisor_override
- elif count_include_pad or not had_padding:
- scale = kernel_size[0] * kernel_size[1] * kernel_size[2]
- else:
- scale = compute_pool_size_without_padding(pd, ph, pw)
- part = ops.truediv(
- grad_loader(
- [
- *prefix,
- ops.indirect_indexing(
- ops.minimum(
- pd, ops.sub(pdend, ops.constant(1, torch.int32))
- ),
- pooled_depth,
- check=False,
- ),
- ops.indirect_indexing(
- ops.minimum(
- ph, ops.sub(phend, ops.constant(1, torch.int32))
- ),
- pooled_height,
- check=False,
- ),
- ops.indirect_indexing(
- ops.minimum(
- pw, ops.sub(pwend, ops.constant(1, torch.int32))
- ),
- pooled_width,
- check=False,
- ),
- ]
- ),
- scale,
- )
- mask = ops.and_(
- ops.and_(ops.lt(pd, pdend), ops.lt(ph, phend)),
- ops.lt(pw, pwend),
- )
- if gradient is None:
- gradient = ops.where(
- mask, part, ops.constant(0.0, torch.float32)
- )
- else:
- gradient = ops.where(mask, ops.add(gradient, part), gradient)
- assert gradient is not None
- return gradient
- rv = Pointwise.create(
- device=grad_output.get_device(),
- dtype=dtype,
- inner_fn=fn,
- ranges=new_size,
- )
- return rv
def _validate_reduction_axis(x, axis):
    """Normalize a reduction ``axis`` argument into a list of non-negative dims.

    Args:
        x: node exposing ``get_size()``; only its rank is used here.
        axis: an ``int``, an iterable of ints, or a falsy value (``None``,
            ``[]``, ``()``) meaning "reduce over every dimension".

    Returns:
        A list of unique, non-negative dimension indices. For a 0-d input the
        result is always ``[]``.

    Raises:
        AssertionError: if an axis is out of range, repeated, or (for a 0-d
            input) anything other than ``()``, ``(0,)`` or ``(-1,)``.
    """
    ndim = len(x.get_size())
    if isinstance(axis, int):
        axis = [axis]
    elif not axis:
        axis = range(ndim)

    if ndim == 0:
        assert tuple(axis) in [(), (0,), (-1,)], f"invalid axis: {axis}"
        return []

    # ndim >= 1 from here on, so negative dims simply wrap around the rank.
    normalized = []
    for dim in axis:
        if dim < 0:
            dim += ndim
        assert 0 <= dim < ndim, f"invalid axis: {dim} for rank {ndim}"
        normalized.append(dim)
    assert len(OrderedSet(normalized)) == len(normalized), "reduction axis not unique"
    return normalized
def _make_reduction_inner(x, *, axis, keepdims, dtype, override_return_dtype):
    """Assemble the keyword arguments describing a reduction over ``x``.

    Returns a dict with ``device``, ``dst_dtype``, ``src_dtype``, ``inner_fn``,
    ``ranges`` (output sizes) and ``reduction_ranges`` (sizes of the reduced
    dims). ``inner_fn(index, reduction_index)`` reassembles a full input index
    from the output index and the reduction index and loads from ``x``.
    """
    if dtype is not None:
        x = to_dtype(x, dtype)
    size = x.get_size()
    reduce_dims = OrderedSet[int](_validate_reduction_axis(x, axis))

    # Split the dimensions into the ones being reduced and the ones kept.
    all_dims = range(len(size))
    reduced_idx = [d for d in all_dims if d in reduce_dims]
    kept_idx = [d for d in all_dims if d not in reduce_dims]
    reduced_sizes = [size[d] for d in reduced_idx]
    kept_sizes = [size[d] for d in kept_idx]

    def loader(index, reduction_index):
        assert len(reduction_index) == len(reduced_idx)
        if keepdims:
            # With keepdims the output index has full rank; drop the reduced dims.
            assert len(index) == len(size)
            index = [index[d] for d in kept_idx]
        assert len(index) == len(kept_idx)
        full_index = [None] * (len(index) + len(reduction_index))
        for pos, var in zip(kept_idx, index):
            full_index[pos] = var
        for pos, var in zip(reduced_idx, reduction_index):
            full_index[pos] = var
        return inner_loader(full_index)

    if keepdims:
        # Reduced dims stay in the output with extent one.
        new_size = [
            sympy.S.One if d in reduce_dims else extent
            for d, extent in enumerate(size)
        ]
    else:
        new_size = kept_sizes

    inner_loader = x.make_loader()
    return {
        "device": x.get_device(),
        "dst_dtype": override_return_dtype or x.get_dtype(),
        "src_dtype": x.get_dtype(),
        "inner_fn": loader,
        "ranges": new_size,
        "reduction_ranges": reduced_sizes,
    }
def make_reduction(reduction_type: ReductionType, override_return_dtype=None):
    """Return a lowering ``inner(x, axis, keepdims, *, dtype)`` for ``reduction_type``.

    The returned function builds the reduction via ``Reduction.create`` and
    realizes the result when it is still a ``Reduction`` node.
    """

    def inner(x, axis=None, keepdims=False, *, dtype=None):
        reduction_kwargs = _make_reduction_inner(
            x,
            axis=axis,
            keepdims=keepdims,
            dtype=dtype,
            override_return_dtype=override_return_dtype,
        )
        reduced = Reduction.create(
            reduction_type=reduction_type, input_node=x, **reduction_kwargs
        )
        # Only realize if reduction isn't unrolled
        still_a_reduction = isinstance(
            reduced.data.data,  # type: ignore[attr-defined, attr-type, union-attr]
            Reduction,
        )
        if still_a_reduction:
            reduced.realize()
        return reduced

    return inner
def _make_scan_inner(x, *, axis, dtype):
    """Assemble the keyword arguments for a single-input scan over ``x``.

    ``x`` is first converted to ``dtype`` when one is given, and ``axis`` is
    passed through ``_validate_dim``.
    """
    if dtype is not None:
        x = to_dtype(x, dtype)
    scan_axis = _validate_dim(x, axis)
    return {
        "device": x.get_device(),
        "dtypes": (x.get_dtype(),),
        "inner_fns": (x.make_loader(),),
        "size": x.get_size(),
        "axis": scan_axis,
    }
@register_lowering(aten.mean)
def mean(x, axis=None, keepdim=False, *, dtype=None):
    """Lower ``aten.mean`` as a sum over ``axis`` divided by the reduced element count."""
    if dtype is not None:
        x = to_dtype(x, dtype)
    input_size = x.get_size()
    axis = _validate_reduction_axis(x, axis)

    # compute in higher-precision until end of mean lowering
    output_dtype = x.get_dtype()
    if output_dtype in (torch.float16, torch.bfloat16):
        x = to_dtype(x, torch.float)

    total = sum_(x, axis, keepdim)
    count = ir.IndexingConstant(
        index=sympy_product(input_size[i] for i in axis),
        dtype=x.get_dtype(),
        device=x.get_device(),
    )
    count = ExpandView.create(count, list(total.get_size()))
    return to_dtype(div(total, count), output_dtype)
def var_mean_sum_(x, axis, correction, keepdim, return_mean):
    """Two-step variance: compute the mean, then sum squared deviations from it.

    Returns ``(var,)`` or ``(var, mean)`` depending on ``return_mean``.
    ``correction`` defaults to 1 when ``None``.
    """
    correction = 1 if correction is None else correction
    input_size = x.get_size()
    axis = _validate_reduction_axis(x, axis)

    # The mean is kept with full rank so it can be subtracted from x directly.
    x_mean = mean(x, axis, keepdim=True)
    if return_mean:
        x_mean.realize()

    squared_diffs = square(sub(x, x_mean))
    squared_sum = sum_(squared_diffs, axis, keepdim)

    count = sympy_product(input_size[i] for i in axis)
    if correction:
        count = sympy.Max(count - correction, 0)
    count = ir.IndexingConstant(
        index=count, dtype=x.get_dtype(), device=x.get_device()
    )
    count = ExpandView.create(count, list(squared_sum.get_size()))
    x_var = div(squared_sum, count)

    if return_mean:
        if not keepdim:
            x_mean = squeeze(x_mean, axis)
        return x_var, x_mean
    return (x_var,)
def use_two_step_variance(x, axis, keepdim):
    """Decide whether variance should use the two-step path instead of Welford.

    True only when the reduction size is a concrete integer below
    ``config.unroll_reductions_threshold`` and the output has more than one element.
    """
    # Instead of unrolling welford, just unroll the simpler two-step var
    axis = _validate_reduction_axis(x, axis)
    reduction_kwargs = _make_reduction_inner(
        x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
    )
    output_ranges = reduction_kwargs["ranges"]
    reduction_numel = sympy_product(reduction_kwargs["reduction_ranges"])

    if not isinstance(reduction_numel, sympy.Integer):
        return False
    if not int(reduction_numel) < config.unroll_reductions_threshold:
        return False
    return sympy_product(output_ranges) != 1
def var_mean_welford_(x, axis, *, correction, keepdim, return_mean):
    """Variance via ``ir.WelfordReduction``.

    The reduction yields a mean and an ``m2`` value; the variance is ``m2``
    divided by ``max(0, N - correction)`` where ``N`` is the number of reduced
    elements. Returns ``(var,)`` or ``(var, mean)`` depending on ``return_mean``.
    """
    if correction is None:
        correction = 1

    reduction_kwargs = _make_reduction_inner(
        x, axis=axis, keepdims=keepdim, dtype=None, override_return_dtype=None
    )
    loader = reduction_kwargs.pop("inner_fn")
    # WelfordReduction.create is given a single ``dtype`` below instead.
    reduction_kwargs.pop("dst_dtype")
    reduction_kwargs.pop("src_dtype")

    welford_mean, m2, _ = ir.WelfordReduction.create(
        inner_fns=(loader,),
        reduction_type="welford_reduce",
        dtype=x.get_dtype(),
        **reduction_kwargs,
    )
    m2.realize()

    dtype = x.get_dtype()
    input_size = x.get_size()
    axis = _validate_reduction_axis(x, axis)
    rnumel = sympy_product(input_size[i] for i in axis)

    def get_constant_or_index_expr(value, dtype):
        # Symbolic sizes must go through index_expr; plain numbers become constants.
        is_symbolic = isinstance(value, sympy.Expr) and not value.is_number
        if not is_symbolic:
            return ops.constant(value, dtype)
        return ops.to_dtype(ops.index_expr(value, torch.int64), dtype)

    def scale_fn(data):
        c = get_constant_or_index_expr(correction, dtype)
        N = get_constant_or_index_expr(rnumel, dtype)
        zero = ops.constant(0, dtype)
        return data / ops.maximum(zero, N - c)

    var = make_pointwise(scale_fn)(m2)

    if not return_mean:
        return (var,)
    welford_mean.realize()
    return var, welford_mean
def var_mean_helper_(x, *, axis, correction, keepdim, return_mean):
    """Shared implementation behind the ``var`` and ``var_mean`` lowerings.

    Computes in ``get_computation_dtype`` of the input dtype, picks the
    two-step or Welford implementation, and converts results back to the
    input dtype. Returns the variance alone, or the full tuple when
    ``return_mean`` is set.
    """
    out_dtype = x.get_dtype()
    compute_dtype = get_computation_dtype(out_dtype)
    x = to_dtype(x, compute_dtype, copy=False)

    if use_two_step_variance(x, axis=axis, keepdim=keepdim):
        impl = var_mean_sum_
    else:
        impl = var_mean_welford_
    results = impl(
        x=x,
        axis=axis,
        correction=correction,
        keepdim=keepdim,
        return_mean=return_mean,
    )

    results = tuple(to_dtype(r, out_dtype, copy=False) for r in results)
    if return_mean:
        return results
    return results[0]
@register_lowering([aten.var, prims.var])
def var_(x, axis=None, *, correction=None, keepdim=False):
    """Lower ``aten.var`` / ``prims.var``; returns only the variance."""
    options = dict(
        axis=axis, correction=correction, keepdim=keepdim, return_mean=False
    )
    return var_mean_helper_(x, **options)
@register_lowering(aten.var_mean)
def var_mean(x, axis=None, *, correction=None, keepdim=False):
    """Lower ``aten.var_mean``; returns the ``(var, mean)`` tuple from the helper."""
    options = dict(
        axis=axis, correction=correction, keepdim=keepdim, return_mean=True
    )
    return var_mean_helper_(x, **options)
def pow_recursive(x, y, dtype):
    """Expand ``x ** y`` for an integer ``y`` into multiplications by repeated squaring.

    Negative exponents are handled by taking the reciprocal of ``x`` first;
    ``y == 0`` yields the constant 1 of ``dtype`` and ``y == 1`` returns ``x``
    unchanged.
    """
    if y < 0:
        return pow_recursive(ops.reciprocal(x), -y, dtype)
    if y == 0:
        return ops.constant(1, dtype)
    if y == 1:
        return x

    half, remainder = divmod(y, 2)
    squared = pow_recursive(x, half, dtype)
    squared = ops.mul(squared, squared)
    # An odd exponent needs one extra factor of x.
    return ops.mul(squared, x) if remainder == 1 else squared
@make_pointwise
def pow_native(a, b):
    """Pointwise power that defers directly to ``ops.pow``."""
    powered = ops.pow(a, b)
    return powered
# ATen fallback handlers for the three pow overloads. They are created with
# add_to_fallback_set=False; the pow lowering in this file calls them for
# integer dtypes.
fallback_pow_tensor_tensor = fallback_handler(
    aten.pow.Tensor_Tensor, add_to_fallback_set=False
)
fallback_pow_scalar = fallback_handler(aten.pow.Scalar, add_to_fallback_set=False)
fallback_pow_tensor_scalar = fallback_handler(
    aten.pow.Tensor_Scalar, add_to_fallback_set=False
)
@register_lowering(aten.pow, broadcast=True)
def pow(a, b):
    """Lower ``aten.pow``.

    Special cases, in order:
      * a float exponent with an integral value is re-dispatched as an ``int``;
      * exponent ``0.5`` becomes ``sqrt``; exponent ``1`` becomes ``clone``;
      * small constant integer exponents (or any non-negative integer exponent
        on an integer dtype) are expanded into multiplications;
      * a scalar base of 1 yields ones, and a scalar base of 2 with a float
        exponent tensor yields ``exp2``;
      * remaining integer-dtype cases fall back to ATen;
      * everything else uses ``ops.pow``.
    """
    # float.is_integer() is False for inf/nan, whereas ``b == int(b)`` raises
    # OverflowError/ValueError on those values.
    if isinstance(b, float) and b.is_integer():
        return pow(a, int(b))
    elif isinstance(b, float) and b == 0.5:
        return sqrt(a)
    elif isinstance(b, int) and b == 1:
        return clone(a)

    # Type promotion ensures all tensor arguments have the same type
    dtype = next(x.get_dtype() for x in (a, b) if isinstance(x, ir.TensorBox))
    is_integer_pow = is_integer_dtype(dtype)

    # Optimize away small fixed powers, or for integers avoid falling back to ATen
    embed_exponent = isinstance(b, int) and (
        -32 < b < 32 or (is_integer_pow and b >= 0)
    )
    if embed_exponent:
        loader = a.make_loader()

        def fn(idx):
            return pow_recursive(loader(idx), b, a.get_dtype())

        return Pointwise.create(
            device=a.get_device(),
            dtype=a.get_dtype(),
            inner_fn=fn,
            ranges=a.get_size(),
        )

    if isinstance(a, Number):
        if a == 1:
            return full_like(b, 1)
        if a == 2 and is_float_dtype(b.get_dtype()):
            return exp2(b)

    if is_integer_pow:
        # ops.pow doesn't work for integers
        if isinstance(a, Number):
            return fallback_pow_scalar(a, b)
        elif isinstance(b, Number):
            return fallback_pow_tensor_scalar(a, b)
        else:
            return fallback_pow_tensor_tensor(a, b)

    return pow_native(a, b)
def mutate_to(changed, val, unsafe_alias=False):
    """Make ``changed`` take on the contents of ``val`` and return ``changed``.

    When ``changed`` is backed by a ``StorageBox`` that is not an input buffer,
    a module buffer, or a ``NopKernel``, its data pointer is simply replaced.
    Otherwise the value is written through
    ``ir.MutationLayoutSHOULDREMOVE.realize_into``.
    """
    changed_data = changed.data if isinstance(changed, TensorBox) else changed
    if isinstance(val, TensorBox):
        val = val.data

    if not isinstance(val, ir.StorageBox):
        # introduce a copy to handle views
        copy_node = Pointwise.create(
            device=changed.get_device(),
            dtype=changed.get_dtype(),
            inner_fn=val.make_loader(),
            ranges=changed.get_size(),
        )
        assert isinstance(copy_node, (BaseView, MutableBox))
        val = copy_node.data
        assert isinstance(val, ir.StorageBox)

    can_swap_storage = (
        isinstance(changed_data, ir.StorageBox)
        and not changed_data.is_input_buffer()
        # In AOTI, module parameters and buffers are not lifted as graph inputs
        and not changed_data.is_module_buffer()
        and not isinstance(changed_data.data, ir.NopKernel)
    )
    if not can_swap_storage:
        ir.MutationLayoutSHOULDREMOVE.realize_into(
            val, changed_data, unsafe_alias=unsafe_alias
        )
        return changed

    # Fast path, just swing the data pointer
    val.realize()
    changed_data.data = val.data
    return changed
@register_lowering(aten.fill_)
def fill_(x, fill_value):
    """Lower ``aten.fill_`` by mutating ``x`` to a ``full_like`` tensor of ``fill_value``."""
    filled = full_like(x, fill_value)
    return mutate_to(x, filled)
@register_lowering(aten.copy_, type_promotion_kind=None)
def copy_(dst, src, non_blocking=False):
    """Lower ``aten.copy_``: match ``src`` to ``dst``'s device, dtype and size, then mutate ``dst``."""
    # dst.copy_(dst) can happen from the reinplacing pass
    if dst is src:
        return dst
    on_device = to_device(src, dst.get_device())
    converted = to_dtype(on_device, dst.get_dtype())
    expanded = expand(converted, dst.get_size())
    return mutate_to(dst, expanded)
@make_pointwise
def floordiv(a, b):
    """Pointwise floor division via ``ops.floordiv``."""
    quotient = ops.floordiv(a, b)
    return quotient
@make_pointwise
def truncdiv(a, b):
    """Pointwise truncating division via ``ops.truncdiv``."""
    quotient = ops.truncdiv(a, b)
    return quotient
@register_lowering(aten.div, broadcast=True)
def div_mode(a, b, rounding_mode=None):
    """Lower ``aten.div`` with an optional ``rounding_mode`` of ``"floor"`` or ``"trunc"``.

    Integer operands use the dedicated integer division lowerings; other
    operands round the result of true division.
    """
    both_integer = is_integer_type(a) and is_integer_type(b)
    both_boolean = is_boolean_type(a) and is_boolean_type(b)

    # floordiv and truncdiv need special handling for integer tensors on Triton,
    # see the discussion at https://github.com/triton-lang/triton/issues/605
    if rounding_mode == "floor":
        assert not both_boolean, "floordiv operands can not be boolean at the same time"
        if both_integer:
            return floordiv(a, b)
        return floor(div(a, b))

    if rounding_mode == "trunc":
        assert not both_boolean, "truncdiv operands can not be boolean at the same time"
        if both_integer:
            return truncdiv(a, b)
        return trunc(div(a, b))

    return div(a, b)
@register_lowering([aten.mul], broadcast=True)
def mul(a, b):
    """Lower ``aten.mul``; two boolean operands are lowered as ``logical_and``."""
    if is_boolean_type(a) and is_boolean_type(b):
        return logical_and(a, b)
    fn = ops_wrapper(aten.mul.__name__)
    return make_pointwise(fn)(a, b)
def get_constant_value(x: ir.IRNode) -> Optional[ir.Constant]:
    """Try convert an arbitrary IR node into an ir.Constant value"""
    # Cheap structural unwrapping first: boxes and views are peeled off
    # recursively, and an ir.Constant is returned as-is. This step is optional
    # but avoids evaluating inner_fn when it is not needed.
    if isinstance(x, ir.MutableBox):
        return get_constant_value(x.data)
    if isinstance(x, ir.BaseView):
        return get_constant_value(x.unwrap_view())
    if isinstance(x, ir.Constant):
        return x

    # Otherwise only Loops nodes can be inspected: run inner_fn under a handler
    # that extracts constants and see whether it produced one from an
    # `ops.constant` call.
    if not isinstance(x, ir.Loops):
        return None

    extractor = torch._inductor.ops_handler.ExtractConstantsHandler(x.get_device())
    with (
        V.set_ops_handler(extractor),
        patch.object(ir.FlexibleLayout, "allow_indexing", True),
    ):
        out = x.inner_fn(*x.inner_fn_args())

    assert isinstance(out, torch._inductor.virtualized.OpsValue)
    return out.value if isinstance(out.value, ir.Constant) else None
# NOTE: prims.div maps to a / b in C, so performs truncation division on
#   integer inputs and true division for floating and complex inputs.
@register_lowering([prims.div], broadcast=True)
def div_prim(a, b):
    """Lower ``prims.div``.

    Integer/boolean operands use ``truncdiv``. Off CPU, division by a constant
    is rewritten as multiplication by its reciprocal. Everything else uses
    ``ops.truediv``.
    """
    if all(is_boolean_type(x) or is_integer_type(x) for x in [a, b]):
        return truncdiv(a, b)

    # Disable CPU optimization to avoid precision issues.
    # see https://github.com/pytorch/pytorch/issues/157959
    divisor = get_constant_value(b)
    if divisor is not None and a.get_device().type != "cpu":
        # Replace divide by constant with multiply by reciprocal
        inverse = (
            math.copysign(float("inf"), divisor.value)
            if divisor.value == 0
            else 1.0 / divisor.value
        )
        return mul(a, inverse)

    def fn(*args):
        return ops.truediv(*args)

    return make_pointwise(fn)(a, b)
@register_lowering(
    [aten.true_divide, aten.div.Tensor],
    broadcast=True,
    type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
)
def div(a, b):
    """Lower true division: promote constants with INT_TO_FLOAT, then defer to ``div_prim``."""
    promoted = promote_constants(
        (a, b), type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    )
    lhs, rhs = promoted
    return div_prim(lhs, rhs)
@register_lowering([aten.fmod, prims.fmod], broadcast=True)
def fmod(a, b):
    """Lower ``fmod``: ``ops.mod`` when ``a`` is integer/boolean, ``ops.fmod`` otherwise."""
    is_integral = is_boolean_type(a) or is_integer_type(a)

    def fn(a, b):
        # The choice depends only on the dtype check made above.
        return ops.mod(a, b) if is_integral else ops.fmod(a, b)

    return make_pointwise(fn)(a, b)
@register_lowering([aten.sum, prims.sum])
def sum_(x, axis=None, keepdims=False, *, dtype=None):
    """Lower ``sum``; integer and boolean inputs accumulate in int64 unless ``dtype`` is given."""
    in_dtype = x.get_dtype()
    if dtype is None and (is_integer_dtype(in_dtype) or is_boolean_dtype(in_dtype)):
        dtype = torch.int64

    reduce_fn = make_reduction("sum", override_return_dtype=dtype)
    return reduce_fn(x, axis, keepdims, dtype=dtype)
# ATen fallback handlers for the scan ops. The scan lowerings in this file
# return these when ir.Scan.create does not produce a result.
fallback_cumsum = fallback_handler(aten.cumsum.default)
fallback_cumprod = fallback_handler(aten.cumprod.default)
fallback_logcumsumexp = fallback_handler(aten.logcumsumexp.default)
fallback_cummax = fallback_handler(aten.cummax.default)
fallback_cummin = fallback_handler(aten.cummin.default)
@register_lowering(aten.cumsum)
def cumsum(x, axis=None, dtype=None):
    """Lower ``aten.cumsum`` as an ``ir.Scan`` with addition, falling back to ATen if needed.

    Integer and boolean inputs default to an int64 result. A 0-d input is
    returned as a dtype-converted copy.
    """
    in_dtype = x.get_dtype()
    if dtype is None and (is_integer_dtype(in_dtype) or is_boolean_dtype(in_dtype)):
        dtype = torch.int64

    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        dtype = dtype or x.get_dtype()
        return to_dtype(x, dtype, copy=True)

    def combine_fn(a_tuple, b_tuple):
        (lhs,), (rhs,) = a_tuple, b_tuple
        return (ops.add(lhs, rhs),)

    scan_kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    (scanned,) = ir.Scan.create(**scan_kwargs, combine_fn=combine_fn)
    if scanned is not None:
        return scanned
    return fallback_cumsum(x, dim=axis, dtype=dtype)
@register_lowering(aten.cumprod)
def cumprod(x, axis=None, dtype=None):
    """Lower ``aten.cumprod`` as an ``ir.Scan`` with multiplication, falling back to ATen if needed.

    Integer and boolean inputs default to an int64 result. A 0-d input is
    returned as a dtype-converted copy.
    """
    in_dtype = x.get_dtype()
    if dtype is None and (is_integer_dtype(in_dtype) or is_boolean_dtype(in_dtype)):
        dtype = torch.int64

    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        dtype = dtype or x.get_dtype()
        return to_dtype(x, dtype, copy=True)

    def combine_fn(a_tuple, b_tuple):
        (lhs,), (rhs,) = a_tuple, b_tuple
        return (ops.mul(lhs, rhs),)

    scan_kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    (scanned,) = ir.Scan.create(**scan_kwargs, combine_fn=combine_fn)
    if scanned is not None:
        return scanned
    return fallback_cumprod(x, dim=axis, dtype=dtype)
@register_lowering(aten.logcumsumexp)
def logcumsumexp(x, dim):
    """Lower ``aten.logcumsumexp`` as an ``ir.Scan``, falling back to ATen if needed."""

    def log_add_exp_helper(a_tuple, b_tuple):
        (lhs,), (rhs,) = a_tuple, b_tuple
        min_v = ops.minimum(lhs, rhs)
        max_v = ops.maximum(lhs, rhs)
        # When both operands are the same infinity the difference below is not
        # usable, so the first operand is returned unchanged in that case.
        mask = (min_v != max_v) | (~ops.isinf(min_v))
        combined = ops.log1p(ops.exp(min_v - max_v)) + max_v
        return (ops.where(mask, combined, lhs),)

    dtype = x.get_dtype()
    if len(x.get_size()) == 0:
        assert dim in [0, -1]
        return clone(x)

    scan_kwargs = _make_scan_inner(x, axis=dim, dtype=dtype)
    (scanned,) = ir.Scan.create(**scan_kwargs, combine_fn=log_add_exp_helper)
    if scanned is not None:
        return scanned
    return fallback_logcumsumexp(x, dim=dim)
@register_lowering(aten.cummax, type_promotion_kind=None)
def cummax(x, axis=None):
    """Lower ``aten.cummax`` as a two-output ``ir.Scan`` returning ``(values, indices)``.

    The scan carries the input value together with its position along ``axis``
    and combines them with the ``"argmax"`` combine function. Falls back to
    ATen when the scan cannot be created.
    """
    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        return clone(x), empty_like(x, dtype=torch.int64)

    dtype = x.get_dtype()
    combine_fn = ir.get_reduction_combine_fn(
        "argmax", dtype=dtype, arg_break_ties_left=False
    )

    def position_loader(idx):
        return ops.index_expr(idx[axis], torch.int64)

    scan_kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    scan_kwargs["dtypes"] = (dtype, torch.int64)
    scan_kwargs["inner_fns"] = (x.make_loader(), position_loader)

    values, indices = ir.Scan.create(**scan_kwargs, combine_fn=combine_fn)  # type: ignore[arg-type]
    if values is not None:
        return values, indices
    return fallback_cummax(x, dim=axis)
@register_lowering(aten.cummin, type_promotion_kind=None)
def cummin(x, axis=None):
    """Lower ``aten.cummin`` as a two-output ``ir.Scan`` returning ``(values, indices)``.

    The scan carries the input value together with its position along ``axis``
    and combines them with the ``"argmin"`` combine function. Falls back to
    ATen when the scan cannot be created.
    """
    if len(x.get_size()) == 0:
        assert axis in [0, -1]
        return clone(x), empty_like(x, dtype=torch.int64)

    dtype = x.get_dtype()
    combine_fn = ir.get_reduction_combine_fn(
        "argmin", dtype=dtype, arg_break_ties_left=False
    )

    def position_loader(idx):
        return ops.index_expr(idx[axis], torch.int64)

    scan_kwargs = _make_scan_inner(x, axis=axis, dtype=dtype)
    scan_kwargs["dtypes"] = (dtype, torch.int64)
    scan_kwargs["inner_fns"] = (x.make_loader(), position_loader)

    values, indices = ir.Scan.create(**scan_kwargs, combine_fn=combine_fn)  # type: ignore[arg-type]
    if values is not None:
        return values, indices
    return fallback_cummin(x, dim=axis)
@register_lowering(aten.prod)
def prod(x, axis=None, keepdims=False, *, dtype=None):
    """Lower ``prod``; integer and boolean inputs accumulate in int64 unless ``dtype`` is given."""
    in_dtype = x.get_dtype()
    if dtype is None and (is_integer_dtype(in_dtype) or is_boolean_dtype(in_dtype)):
        dtype = torch.int64

    reduce_fn = make_reduction("prod", override_return_dtype=dtype)
    return reduce_fn(x, axis, keepdims, dtype=dtype)
@register_lowering(aten.any)
def reduce_any(x, dim=None, keepdim=False):
    """Lower ``aten.any``: cast the input to bool, then apply the ``"any"`` reduction."""
    as_bool = to_dtype(x, torch.bool)
    any_reduction = make_reduction("any")
    return any_reduction(as_bool, axis=dim, keepdims=keepdim)
@register_lowering(aten.max, type_promotion_kind=None)
def reduce_max(x, dim=None, keepdim=False):
    """Lower ``aten.max``: a full reduction without ``dim``, else ``(values, indices)`` along ``dim``."""
    if dim is None:
        return reduce_amax(x, axis=None, keepdims=keepdim)
    values = reduce_amax(x, axis=dim, keepdims=keepdim)
    indices = reduce_argmax(x, axis=dim, keepdims=keepdim)
    return (values, indices)
@register_lowering(aten.min, type_promotion_kind=None)
def reduce_min(x, dim=None, keepdim=False):
    """Lower ``aten.min``: a full reduction without ``dim``, else ``(values, indices)`` along ``dim``."""
    if dim is None:
        return reduce_amin(x, axis=None, keepdims=keepdim)
    values = reduce_amin(x, axis=dim, keepdims=keepdim)
    indices = reduce_argmin(x, axis=dim, keepdims=keepdim)
    return (values, indices)
# Reductions registered directly from make_reduction. reduce_amax/reduce_amin
# and reduce_argmax/reduce_argmin are the helpers used by the max/min
# lowerings in this file; the arg* variants always return int64.
register_lowering(prims.xor_sum)(make_reduction("xor_sum"))
reduce_amax = register_lowering(aten.amax)(make_reduction("max"))
reduce_amin = register_lowering(aten.amin)(make_reduction("min"))
reduce_argmax = register_lowering(aten.argmax)(
    make_reduction("argmax", override_return_dtype=torch.int64)
)
reduce_argmin = register_lowering(aten.argmin)(
    make_reduction("argmin", override_return_dtype=torch.int64)
)

# Pointwise add: accepts alpha, and is overridden with "logical_or" for bool inputs.
add = register_pointwise(
    aten.add, allow_alpha=True, override_fn_when_input_bool="logical_or"
)

# ATen fallback for the stable sort overload, created with add_to_fallback_set=False.
sort_fallback = fallback_handler(aten.sort.stable, add_to_fallback_set=False)
@register_lowering(aten.sort.stable, type_promotion_kind=None)
def sort_stable(x, *, stable=None, dim=-1, descending=False):
    """Lower ``aten.sort.stable`` to ``ir.Sort``, returning ``(values, indices)``.

    Indices are tracked as int16 during the sort and widened to int64 at the
    end, so the lowering is only used when the sorted dimension is statically
    known to be smaller than the int16 maximum. Otherwise, or when
    ``ir.Sort.create`` yields no result, the ATen fallback is used.
    """
    if stable is None:
        stable = False

    shape = x.get_size()
    device = x.get_device()
    dim = canonicalize_dim(len(shape), dim)
    if len(shape) == 0:
        return clone(x), _full(0, device, torch.int64, shape)

    # The input has at least one dimension from here on, so ``shape[dim]`` is
    # always valid and needs no rank guard.
    dim_size = shape[dim]
    if not V.graph.sizevars.statically_known_lt(dim_size, torch.iinfo(torch.int16).max):
        return sort_fallback(x, stable=stable, dim=dim, descending=descending)

    # Positions 0..dim_size-1 along ``dim``, broadcast to the input shape.
    indices = iota(
        dim_size, start=0, step=1, dtype=torch.int16, device=device, requires_grad=False
    )
    view_shape = [1] * len(shape)
    view_shape[dim] = dim_size
    indices = view(indices, view_shape)
    indices = expand(indices, shape)

    values, indices = ir.Sort.create(
        device=device,
        dtypes=(x.dtype, indices.dtype),
        inner_fns=(x.make_loader(), indices.make_loader()),
        size=shape,
        axis=dim,
        stable=stable,
        descending=descending,
    )
    if values is None:
        return sort_fallback(x, stable=stable, dim=dim, descending=descending)

    assert indices is not None
    return values, to_dtype(indices, torch.int64)
@register_lowering(aten.sort.default, type_promotion_kind=None)
def sort(x, dim=-1, descending=False):
    """Lower ``aten.sort.default`` by delegating to ``sort_stable`` with ``stable=False``."""
    sorted_pair = sort_stable(x, stable=False, dim=dim, descending=descending)
    return sorted_pair
def register_pointwise_numeric(op, name=None, triton_fallback=None):
    """Register a pointwise lowering for ``op`` using INT_TO_FLOAT type promotion."""
    promotion = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    return register_pointwise(
        op,
        name=name,
        type_promotion_kind=promotion,
        triton_fallback=triton_fallback,
    )
def register_pointwise_numeric_ldf64(op: torch._ops.OpOverloadPacket):
    """Like an INT_TO_FLOAT pointwise registration, but first records the op
    name via ``register_op_requires_libdevice_fp64``."""
    register_op_requires_libdevice_fp64(op.__name__)
    promotion = ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
    return register_pointwise(op, type_promotion_kind=promotion)
# Pointwise lowerings. Results bound to a name (rsqrt, exp, sqrt, sub, ...)
# are reused by other lowerings in this file; the rest are registered for
# their side effect only. The *_numeric helpers use INT_TO_FLOAT promotion,
# and the *_ldf64 variant additionally records the op via
# register_op_requires_libdevice_fp64.
rsqrt = register_pointwise_numeric(aten.rsqrt)
exp = register_pointwise_numeric_ldf64(aten.exp)
exp2 = register_pointwise_numeric(aten.exp2)
expm1 = register_pointwise_numeric(aten.expm1)
relu = register_pointwise(aten.relu)
sigmoid = register_pointwise_numeric_ldf64(aten.sigmoid)
sqrt = register_pointwise_numeric_ldf64(aten.sqrt)
square = register_pointwise(aten.square)
sub = register_pointwise(aten.sub, allow_alpha=True)
register_pointwise_numeric_ldf64(aten.cos)
register_pointwise_numeric_ldf64(aten.sin)
abs = register_pointwise(aten.abs)

# Bitwise ops; bitwise_not is overridden with "logical_not" for bool inputs.
bitwise_and = register_pointwise(aten.bitwise_and)
bitwise_left_shift = register_pointwise(aten.bitwise_left_shift)
bitwise_not = register_pointwise(
    aten.bitwise_not, override_fn_when_input_bool="logical_not"
)
bitwise_or = register_pointwise(aten.bitwise_or)
bitwise_right_shift = register_pointwise(aten.bitwise_right_shift)
bitwise_xor = register_pointwise(aten.bitwise_xor)

register_pointwise_numeric(aten.lgamma)
erf = register_pointwise_numeric(aten.erf)
# aten.special_erf shares the erf lowering.
register_lowering(
    aten.special_erf, type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
)(erf)
register_pointwise_numeric(aten.log1p)
register_pointwise_numeric(aten.tan)
register_pointwise_numeric(aten.tanh)
register_pointwise_numeric_ldf64(aten.log)

# Logical ops: inputs are converted to bool and the result dtype is forced to
# torch.bool, with no type promotion.
logical_and = register_pointwise(
    aten.logical_and,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
logical_not = register_pointwise(
    aten.logical_not,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
logical_or = register_pointwise(
    aten.logical_or,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
logical_xor = register_pointwise(
    aten.logical_xor,
    type_promotion_kind=None,
    convert_input_to_bool=True,
    override_return_dtype=torch.bool,
)
# More pointwise lowerings. clamp_min/clamp_max reuse maximum/minimum, and
# aten._neg_view reuses neg.
maximum = register_pointwise(aten.maximum)
minimum = register_pointwise(aten.minimum)
register_lowering(aten.clamp_min)(maximum)
register_lowering(aten.clamp_max)(minimum)
neg = register_pointwise(aten.neg)
# aten.abs is already registered (and bound to ``abs``) earlier in this file,
# so it is not registered a second time here.
reciprocal = register_pointwise_numeric(aten.reciprocal)
register_pointwise(aten.remainder)
sign = register_pointwise(aten.sign, override_fn_when_input_bool="identity")
register_pointwise(aten.ceil)
register_pointwise(aten.signbit, override_return_dtype=torch.bool)

register_lowering(aten._neg_view)(neg)

# Comparisons always produce torch.bool.
register_pointwise(aten.le, override_return_dtype=torch.bool)
register_pointwise(aten.lt, override_return_dtype=torch.bool)
register_pointwise(aten.ge, override_return_dtype=torch.bool)
gt = register_pointwise(aten.gt, override_return_dtype=torch.bool)
register_pointwise(aten.eq, override_return_dtype=torch.bool)
register_pointwise(aten.ne, override_return_dtype=torch.bool)

# Remaining numeric ops, registered with INT_TO_FLOAT promotion.
register_pointwise_numeric(aten.cosh)
register_pointwise_numeric(aten.sinh)
register_pointwise_numeric(aten.acos)
register_pointwise_numeric(aten.acosh)
register_pointwise_numeric(aten.asin)
register_pointwise_numeric(aten.asinh)
register_pointwise_numeric(aten.atan2)
register_pointwise_numeric(aten.atan)
register_pointwise_numeric(aten.atanh)
register_pointwise_numeric(aten.copysign)
register_pointwise_numeric(aten.erfc)
register_pointwise_numeric(aten.erfinv)
register_pointwise_numeric(aten.hypot)
register_pointwise_numeric(aten.log10)
register_pointwise_numeric(aten.log2)
register_pointwise_numeric(aten.nextafter)
- from .codegen.common import BackendFeature, pointwise_overrides_data
def _get_pointwise_overrides(ns, name):
    """Yield ``(op, type_promotion_kind, triton_fallback)`` for the override ``name`` in namespace ``ns``.

    Nothing is yielded when ``ns`` has no attribute for the override. For an
    ``OpOverloadPacket`` one tuple is yielded per overload; otherwise a single
    tuple for the op itself. ``triton_fallback`` is a fallback handler when the
    override data has no ``triton`` entry, else ``None``.
    """
    data = pointwise_overrides_data[name]
    op = getattr(ns, data.name, None)
    if op is None:
        return

    def make_triton_fallback(target):
        return fallback_handler(target) if data.triton is None else None

    if isinstance(op, torch._ops.OpOverloadPacket):
        targets = (getattr(op, overload_name) for overload_name in op.overloads())
    else:
        targets = (op,)

    for target in targets:
        yield target, data.type_promotion_kind, make_triton_fallback(target)
# Register every pointwise override, for the aten namespace first and then
# for prims, one override name at a time.
for name in pointwise_overrides_data:
    for op, type_promotion_kind, triton_fallback in itertools.chain(
        _get_pointwise_overrides(aten, name),
        _get_pointwise_overrides(prims, name),
    ):
        register_pointwise(
            op,
            name=name,
            type_promotion_kind=type_promotion_kind,
            triton_fallback=triton_fallback,
        )
# Foreach lowerings, each built from the corresponding pointwise lowering.
# The results bound to names (foreach_add_list, foreach_mul_scalar, ...) are
# reused further down for the in-place foreach variants.
foreach_add_list = register_foreach_pointwise(
    aten._foreach_add.List, add, allow_alpha=True
)
foreach_add_scalar = register_foreach_pointwise(
    aten._foreach_add.Scalar, add, allow_alpha=True
)
register_foreach_pointwise(aten._foreach_add.Tensor, add, allow_alpha=True)
foreach_mul_list = register_foreach_pointwise(aten._foreach_mul.List, mul)
register_foreach_pointwise(aten._foreach_mul.Tensor, mul)
foreach_mul_scalar = register_foreach_pointwise(aten._foreach_mul.Scalar, mul)
register_foreach_pointwise(aten._foreach_sub.List, sub)
register_foreach_pointwise(aten._foreach_sub.Scalar, sub)
register_foreach_pointwise(aten._foreach_neg.default, neg)
register_foreach_pointwise(aten._foreach_abs.default, abs)
register_foreach_pointwise(aten._foreach_pow.Scalar, pow)
register_foreach_pointwise(aten._foreach_pow.List, pow)
register_foreach_pointwise(aten._foreach_pow.ScalarAndTensor, pow)
foreach_div_list = register_foreach_pointwise(aten._foreach_div.List, div)
register_foreach_pointwise(aten._foreach_div.Tensor, div)
foreach_div_scalar = register_foreach_pointwise(aten._foreach_div.Scalar, div)
register_foreach_pointwise(aten._foreach_sqrt, sqrt)
register_foreach_pointwise(aten._foreach_rsqrt, rsqrt)
register_foreach_pointwise(aten._foreach_maximum.List, maximum)
register_foreach_pointwise(aten._foreach_maximum.Scalar, maximum)
register_foreach_pointwise(aten._foreach_minimum.List, minimum)
register_foreach_pointwise(aten._foreach_minimum.Scalar, minimum)
# clamp_min / clamp_max map onto maximum / minimum respectively.
register_foreach_pointwise(aten._foreach_clamp_min.List, maximum)
register_foreach_pointwise(aten._foreach_clamp_min.Scalar, maximum)
register_foreach_pointwise(aten._foreach_clamp_max.List, minimum)
register_foreach_pointwise(aten._foreach_clamp_max.Scalar, minimum)
register_foreach_pointwise(aten._foreach_reciprocal, reciprocal)
register_foreach_pointwise(aten._foreach_sign, sign)
foreach_copy = register_foreach_pointwise(aten._foreach_copy, copy)
# these are only encountered as outputs of the graph
# reinplacing epilogue copies improves compile time
# by removing extra buffers sent to the scheduler.
def register_foreach_inplace(aten_op, outplace_aten_op, outplace_op):
    """Register an in-place foreach op in terms of its out-of-place lowering.

    Records the op pair in ``inplaceable_foreach_ops`` / ``inplace_foreach_ops``
    and registers a lowering that runs ``outplace_op`` and then mutates each
    tensor of the first argument to the matching result.
    """
    inplaceable_foreach_ops[outplace_aten_op] = aten_op
    inplace_foreach_ops.add(aten_op)

    def fn(*args, **kwargs):
        outputs = outplace_op(*args, **kwargs)
        return [
            mutate_to(target, output, unsafe_alias=True)
            for target, output in zip(args[0], outputs)
        ]

    _register_foreach_lowering(aten_op, fn)
# In-place foreach variants, each paired with its out-of-place aten op and
# the foreach lowering registered for it above.
register_foreach_inplace(
    aten._foreach_add_.List, aten._foreach_add.List, foreach_add_list
)
register_foreach_inplace(
    aten._foreach_add_.Scalar, aten._foreach_add.Scalar, foreach_add_scalar
)
register_foreach_inplace(
    aten._foreach_mul_.List, aten._foreach_mul.List, foreach_mul_list
)
register_foreach_inplace(
    aten._foreach_mul_.Scalar, aten._foreach_mul.Scalar, foreach_mul_scalar
)
register_foreach_inplace(
    aten._foreach_div_.List, aten._foreach_div.List, foreach_div_list
)
register_foreach_inplace(
    aten._foreach_div_.Scalar, aten._foreach_div.Scalar, foreach_div_scalar
)
register_foreach_inplace(
    aten._foreach_copy_.default, aten._foreach_copy.default, foreach_copy
)
def register_inplace(aten_op, outplace_op):
    """Register ``aten_op`` as: run ``outplace_op``, cast to the first argument's
    dtype, and mutate the first argument to the result."""

    @register_lowering(aten_op, type_promotion_kind=None)
    def fn(*args, **kwargs):
        target = args[0]
        computed = outplace_op(*args, **kwargs)
        computed = to_dtype(computed, target.get_dtype())
        return mutate_to(target, computed)

    return fn
# In-place variants: each pairs an in-place aten op with the lowering function
# of its out-of-place counterpart (see register_inplace).
register_inplace(aten.add_, add)
register_inplace(aten.bitwise_and_, bitwise_and)
register_inplace(aten.bitwise_left_shift_, bitwise_left_shift)
register_inplace(aten.bitwise_not_, bitwise_not)
register_inplace(aten.bitwise_or_, bitwise_or)
register_inplace(aten.bitwise_right_shift_, bitwise_right_shift)
register_inplace(aten.bitwise_xor_, bitwise_xor)
register_inplace(aten.mul_, mul)
register_inplace(aten.div_.Tensor, div)
register_inplace(aten.div_.Tensor_mode, div_mode)
register_inplace(aten.logical_and_, logical_and)
register_inplace(aten.logical_not_, logical_not)
register_inplace(aten.logical_or_, logical_or)
register_inplace(aten.logical_xor_, logical_xor)
register_inplace(aten.sub_, sub)
register_inplace(aten.relu_, relu)
register_inplace(aten.sigmoid_, sigmoid)

# The dunder operator overloads reuse the bitwise lowerings.
register_lowering(aten.__and__)(bitwise_and)
register_lowering(aten.__lshift__)(bitwise_left_shift)
register_lowering(aten.__or__)(bitwise_or)
register_lowering(aten.__rshift__)(bitwise_right_shift)
register_lowering(aten.__xor__)(bitwise_xor)

# register_inplace calls its second argument directly on the lowering's
# arguments, so it must be a lowering function (as in the calls above), not an
# aten op packet.  Use the same lowerings registered for the out-of-place
# dunder ops.
register_inplace(aten.__iand__, bitwise_and)
register_inplace(aten.__ilshift__, bitwise_left_shift)
register_inplace(aten.__ior__, bitwise_or)
register_inplace(aten.__irshift__, bitwise_right_shift)
register_inplace(aten.__ixor__, bitwise_xor)
@register_lowering(aten.sym_constrain_range)
def sym_constrain_range(a, min=None, max=None):
    """Lowering for ``aten.sym_constrain_range``; produces nothing."""
    # All arguments are ignored here.
    return None
@register_lowering(aten.sym_size.int)
def sym_size(a, dim):
    """Return ``val.node.expr`` for the ``val`` recorded on the current node.

    ``a`` and ``dim`` are not consulted; the value comes from
    ``V.graph.current_node.meta["val"]``.
    """
    current = V.graph.current_node
    val = current.meta["val"]
    # Note [Can val be an int?]
    # ~~~~~~~~~~~~~~~~~~~~~~~~~
    # An FX graph could in principle contain a size/stride call whose val is
    # a plain int rather than a SymInt.  We keep the invariant that this
    # never happens: if a size/stride call is known to always yield a
    # constant int, there is no need to trace the call at all; the integer
    # should simply be constant propagated.
    assert isinstance(val, torch.SymInt), (
        f"Expect val to be torch.SymInt but got val={val}"
    )
    return val.node.expr
@register_lowering(aten.sym_stride.int)
def sym_stride(a, dim):
    """Return ``val.node.expr`` for the ``val`` recorded on the current node.

    ``a`` and ``dim`` are not consulted; the value comes from
    ``V.graph.current_node.meta["val"]``.
    """
    current = V.graph.current_node
    val = current.meta["val"]
    # See Note [Can val be an int?]
    assert isinstance(val, torch.SymInt), (
        f"Expect val to be torch.SymInt but got val={val}"
    )
    return val.node.expr
@register_lowering(aten.sym_numel)
def sym_numel(a):
    """Lower ``aten.sym_numel`` to ``a.get_numel()``."""
    numel = a.get_numel()
    return numel
# Register each function in magic_methods as the lowering for the operator
# that method_to_operator returns for its method name.
for method, func in magic_methods.items():
    register_lowering(method_to_operator(method))(func)  # type: ignore[arg-type]
@register_lowering(torch.sym_sum)
def sym_sum(args):
    """Lower ``torch.sym_sum`` to a single ``sympy.Add`` over ``args``."""
    total = sympy.Add(*args)
    return total
@register_lowering(aten._foobar)
def foobar(self, *args, **kwargs):
    """Lowering for ``aten._foobar`` that always raises.

    Raises:
        NotImplementedError: unconditionally.
    """
    message = "Helpful for debugging"
    raise NotImplementedError(message)
@register_lowering(torch.ops._inductor_test.realize)
def _realize(x):
    """Realize ``x`` and return ``clone(x)``."""
    x.realize()
    cloned = clone(x)
    return cloned
@register_lowering(torch.ops.inductor.resize_storage_bytes_)
def resize_storage_bytes_(variable, new_size):
    """Lower ``inductor.resize_storage_bytes_``.

    Realizes ``variable``, constructs an ``ir.ResizeStorageBytes`` node for
    it with ``new_size``, and returns ``variable`` itself.
    """
    variable.realize()
    # The node is constructed only for its effect; its value is not used.
    _ = ir.ResizeStorageBytes(variable, new_size)
    return variable
@register_lowering(torch.ops.aten.set_.source_Tensor)
def set__source_tensor(self, source_tensor):
    """Lower ``aten.set_.source_Tensor``.

    Both operands are realized, then an ``ir.SetSourceTensorKernel`` built
    from them is wrapped with ``TensorBox.create`` and returned.
    """
    for operand in (self, source_tensor):
        operand.realize()
    kernel = ir.SetSourceTensorKernel(self, source_tensor)
    return TensorBox.create(kernel)
if hasattr(torch.ops.fsdp, "copy_"):

    @register_lowering(torch.ops.fsdp.copy_.default)
    def fsdp_copy_(dst, src):
        """Lower ``fsdp.copy_``: write ``src`` into ``dst`` and return the result.

        ``src`` is moved to the device of ``dst``, converted to its dtype and
        expanded to its size before ``mutate_to`` is applied.
        """
        # dst.copy_(dst) can happen from the reinplacing pass
        if dst is src:
            return dst

        on_device = to_device(src, dst.get_device())
        converted = to_dtype(on_device, dst.get_dtype())
        expanded = expand(converted, dst.get_size())
        return mutate_to(dst, expanded)
@register_lowering(torch.ops.aten.resize)
def resize(x, size, *, memory_format=None):
    """Lower ``aten.resize`` to a pointwise op over the flattened input.

    For every output position the flat offset under the requested memory
    format is computed.  Offsets below the old element count load from the
    flattened ``x``; all others get a fill value.

    Args:
        x: input ``TensorBox``.
        size: target shape (list or tuple).
        memory_format: optional torch memory format; ``None`` is treated as
            ``torch.contiguous_format``.

    Raises:
        RuntimeError: if ``memory_format`` is ``torch.preserve_format``.
    """
    assert isinstance(x, TensorBox)
    assert isinstance(size, (list, tuple))

    if memory_format is None:
        memory_format = torch.contiguous_format
    if memory_format == torch.preserve_format:
        raise RuntimeError(f"unsupported memory format: {memory_format}")

    # Rank checks for the channels-last formats.
    if memory_format == torch.channels_last:
        assert len(size) == 4
    elif memory_format == torch.channels_last_3d:
        assert len(size) == 5

    # Read these before x.data is possibly replaced below.
    old_numel = x.get_numel()
    dtype = x.get_dtype()
    device = x.get_device_or_error()

    if isinstance(x.data, ir.BaseView):
        x.data = x.data.unwrap_view()

    fill_deterministically = (
        torch.are_deterministic_algorithms_enabled()
        and torch.utils.deterministic.fill_uninitialized_memory  # type: ignore[attr-defined]
    )
    if not fill_deterministically:
        # using zero as that is what empty does
        uninitialized_val = 0.0
    elif is_float_dtype(dtype):
        uninitialized_val = float("nan")
    elif is_integer_dtype(dtype):
        uninitialized_val = torch.iinfo(dtype).max
    else:
        uninitialized_val = True

    if V.graph.sizevars.statically_known_equals(old_numel, 0):  # type: ignore[arg-type]
        # Nothing to copy from the input: the output is entirely fill value.
        return full(size, uninitialized_val, dtype=dtype, device=device)

    x_flat = as_strided(x, [old_numel], [1])
    flat_loader = x_flat.make_loader()
    out_stride = ir.FlexibleLayout.stride_ordered_for_memory_format(size, memory_format)
    out_indexer = ir.FixedLayout(device, dtype, size, out_stride).make_indexer()

    def inner_fn(idx):
        flat_index = out_indexer(idx)
        # Only offsets below the old element count are loaded from the input.
        in_bounds = ops.lt(
            ops.index_expr(flat_index, torch.int64),
            ops.index_expr(old_numel, torch.int64),
        )
        return ops.masked(in_bounds, lambda: flat_loader([flat_index]), uninitialized_val)

    return Pointwise.create(
        device=device, dtype=dtype, inner_fn=inner_fn, ranges=list(size)
    )
from torch._higher_order_ops.auto_functionalize import auto_functionalized


# auto_functionalized gets no dedicated lowering here; it is handed to
# make_fallback instead.
make_fallback(auto_functionalized)
@register_lowering(triton_kernel_wrapper_mutation)
def triton_kernel_wrap_(
    *,
    kernel_idx,
    constant_args_idx,
    grid,
    tma_descriptor_metadata,
    kwargs,
):
    """Lower ``triton_kernel_wrapper_mutation`` to a user-defined Triton kernel.

    The constant arguments stored under ``constant_args_idx`` are looked up
    in the kernel side table and merged over ``kwargs`` to form the kernel
    arguments.  Returns the entries of ``kwargs`` whose values are
    ``TensorBox`` instances.
    """
    from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

    constant_args = kernel_side_table.get_constant_args(constant_args_idx)
    # On a key clash the constant args take precedence over kwargs.
    kernel_args = {**kwargs, **constant_args}
    ir.UserDefinedTritonKernel(
        kernel_idx=kernel_idx,
        grid=grid,
        tma_descriptor_metadata=tma_descriptor_metadata,
        kernel_args=kernel_args,
    )

    tensor_kwargs = {}
    for name, value in kwargs.items():
        if isinstance(value, TensorBox):
            tensor_kwargs[name] = value
    return tensor_kwargs
@register_lowering(torch.ops.higher_order.cond, type_promotion_kind=None)
def cond(pred, true_fn, false_fn, operands):
    """Lower ``torch.cond`` through ``ir.Conditional.create``.

    If the predicate or any operand is an ``IRNode`` for which ``is_triton``
    holds, a reason is recorded in ``V.graph.disable_cudagraphs_reason``.
    Returns the created results, each wrapped with ``TensorBox.create``.
    """
    touches_triton = any(
        isinstance(node, IRNode) and is_triton(node) for node in [pred, *operands]
    )
    if touches_triton:
        reason = "control flow operator: torch.cond."
        stack_trace = V.graph.current_node.meta.get("stack_trace", None)
        if stack_trace:
            reason = f"{reason} Found from : \n {stack_trace}"
        V.graph.disable_cudagraphs_reason = reason

    outputs = ir.Conditional.create(pred, true_fn, false_fn, operands)
    return [TensorBox.create(out) for out in outputs]
@register_lowering(torch.ops.higher_order.while_loop, type_promotion_kind=None)
def while_loop(cond_fn, body_fn, carried_inputs, additional_inputs, stack_output=False):
    """Lower ``torch.while_loop`` through ``ir.WhileLoop.create``.

    If any carried or additional input is an ``IRNode`` for which
    ``is_triton`` holds, a reason is recorded in
    ``V.graph.disable_cudagraphs_reason``.  Each result is returned as a
    ``TensorBox``.

    Raises:
        RuntimeError: if a result is not a ``TensorBox``, ``ir.StorageBox``
            or ``ir.MultiOutput``.
    """
    touches_triton = any(
        isinstance(node, IRNode) and is_triton(node)
        for node in carried_inputs + additional_inputs
    )
    if touches_triton:
        reason = "control flow operator: torch.while_loop."
        stack_trace = V.graph.current_node.meta.get("stack_trace", None)
        if stack_trace:
            reason = f"{reason} Found from : \n {stack_trace}"
        V.graph.disable_cudagraphs_reason = reason

    def _map_output(out: Any):
        # Normalize each supported result type to a TensorBox.
        if isinstance(out, TensorBox):
            return out
        if isinstance(out, ir.StorageBox):
            return TensorBox(out)
        if isinstance(out, ir.MultiOutput):
            return TensorBox.create(out)
        raise RuntimeError(f"NYI unsupported output type: {type(out)}")

    result = ir.WhileLoop.create(
        cond_fn, body_fn, carried_inputs, additional_inputs, stack_output
    )
    assert isinstance(result, Sequence)
    return [_map_output(out) for out in result]
# while_loop_stack_output shares the while_loop lowering, with stack_output
# bound to True.
register_lowering(
    torch.ops.higher_order.while_loop_stack_output, type_promotion_kind=None
)(functools.partial(while_loop, stack_output=True))
@register_lowering(torch.ops.higher_order.invoke_subgraph, type_promotion_kind=None)
def invoke_subgraph(subgraph_fn: ir.Subgraph, identifier: str, *operands):
    """Lower ``invoke_subgraph`` through ``ir.InvokeSubgraph.create``.

    ``identifier`` is accepted but not used here.  Returns the created
    results, each wrapped with ``TensorBox.create``.
    """
    outputs = ir.InvokeSubgraph.create(subgraph_fn, *operands)
    return [TensorBox.create(out) for out in outputs]  # type: ignore[call-overload]
@register_lowering(torch._higher_order_ops.invoke_quant, type_promotion_kind=None)
def invoke_quant_tracer(subgraph_fn: ir.Subgraph, *operands, scheme=None):
    """Lower ``invoke_quant`` by running the subgraph's nodes in ``V.graph``.

    Placeholders are bound to ``operands`` by node position, the output
    node's values are realized and recorded in ``V.graph.invoke_quant_ops``
    (and in ``low_precision_codegen_ops`` when the node's ``quant_options``
    request it), and every other node is run with ``V.graph.run_node``.

    Returns the value produced for the output node, or ``None`` if the
    subgraph has no output node.
    """
    quant_options = V.graph.current_node.meta.get("quant_options", None)
    assert quant_options is not None

    output = None
    graph_nodes = subgraph_fn.graph_module.graph.nodes
    for position, node in enumerate(graph_nodes):
        if node.op == "placeholder":
            V.graph.env[node] = operands[position]
            continue

        # todo getattr
        if node.op != "output":
            V.graph.env[node] = V.graph.run_node(node)
            continue

        args, kwargs = V.graph.fetch_args_kwargs_from_env(node)
        for value in itertools.chain(args, kwargs.values()):
            value.realize()
            if quant_options.codegen_low_precision:
                V.graph.low_precision_codegen_ops.add(value.get_operation_name())
            V.graph.invoke_quant_ops.add(value.get_operation_name())

        output = torch.fx.Interpreter.output(V.graph, node, args, kwargs)

    return output
@register_lowering(associative_scan_op, type_promotion_kind=None)
def associative_scan(
    combine_fn: ir.Subgraph, xs, additional_inputs: tuple[torch.Tensor]
):
    """Lower ``associative_scan`` to ``ir.Scan`` along axis 0.

    ``combine_fn`` is lowered as a pointwise subgraph that takes the leaves
    of the left operand followed by the leaves of the right operand.

    Raises:
        RuntimeError: if ``additional_inputs`` is non-empty, or if
            ``ir.Scan.create`` yields ``None`` as its first result.
    """
    from .subgraph_lowering import InputDescriptor, lower_pointwise_subgraph

    if len(additional_inputs) > 0:
        raise RuntimeError(
            "Unable to generate code for associative_scan op, because there are lifted arguments"
        )

    # One descriptor per input for the lhs, then again for the rhs.
    subgraph_inputs = []
    for x in itertools.chain(xs, xs):
        subgraph_inputs.append(
            InputDescriptor(dtype=x.get_dtype(), device=x.get_device())
        )
    lowered_combine_fn = lower_pointwise_subgraph(combine_fn, subgraph_inputs)  # type: ignore[var-annotated]

    def wrapped_combine_fn(lhs, rhs):
        lhs_leaves = pytree.tree_leaves(lhs)
        rhs_leaves = pytree.tree_leaves(rhs)
        return lowered_combine_fn(*lhs_leaves, *rhs_leaves)

    scan_kwargs = _make_scan_inner(xs[0], axis=0, dtype=None)
    scan_kwargs["dtypes"] = tuple(x.get_dtype() for x in xs)
    scan_kwargs["inner_fns"] = tuple(x.make_loader() for x in xs)

    result = ir.Scan.create(
        combine_fn=wrapped_combine_fn,
        can_fallback_to_aten=False,
        **scan_kwargs,
    )
    if result[0] is None:
        raise RuntimeError("Unable to generate code for associative_scan op")
    return result
@register_lowering(torch.ops.prims._sink_tokens.default)
def _sink_tokens(tokens):
    """Lowering for ``prims._sink_tokens``; ``tokens`` is ignored."""
    # Nothing is produced for this op.
    return None
@register_lowering(torch.ops.higher_order.with_effects, type_promotion_kind=None)
def with_effects(token, op, *args, **kwargs):
    """Lower ``with_effects`` through ``ir.EffectfulKernel.create``.

    Returns a tuple whose first element is the entry of
    ``V.graph.effectful_ops`` for the op's effect key, followed by the op's
    result: nothing if the result is ``None``, the unpacked elements if it
    is a tuple, and the result as a single element otherwise.
    """
    result = ir.EffectfulKernel.create(op, *args, **kwargs)

    from torch._higher_order_ops.effects import get_effect_key

    effect_type = get_effect_key(op, args, kwargs)
    assert effect_type is not None
    effectful_kernel = V.graph.effectful_ops[effect_type]

    if result is None:
        return (effectful_kernel,)

    result = pytree.tree_map_only(ir.MultiOutput, TensorBox.create, result)
    # See [NOTE: with_effects return type]
    # Only a tuple result is unpacked; a list is kept as one element.
    if isinstance(result, tuple):
        return (effectful_kernel, *result)
    return (effectful_kernel, result)
from .comm_lowering import register_comm_lowerings


# Run the registration function provided by comm_lowering.
register_comm_lowerings()
@register_lowering(inductor_prims.prepare_softmax_online, type_promotion_kind=None)
def prepare_softmax_online(x, dim):
    """Lower ``inductor_prims.prepare_softmax_online``.

    When no split is needed and the reduction is at least
    ``config.unroll_reductions_threshold`` elements, max and sum are computed
    in one pass with ``OnlineSoftmaxReduction``.  Otherwise a warning is
    issued and they are computed separately (amax, then the sum of
    ``exp(x - amax)``).

    Returns a ``(max, sum)`` pair, both reduced over ``dim`` with
    ``keepdims=True``.
    """
    kwargs = _make_reduction_inner(
        x, axis=dim, keepdims=True, dtype=None, override_return_dtype=None
    )

    rnumel = V.graph.sizevars.simplify(sympy_product(kwargs["reduction_ranges"]))
    hint, num_split = ir.Reduction.num_splits(
        **kwargs,
        reduction_type="online_softmax_reduce",  # type: ignore[arg-type]
        reduction_numel=rnumel,
    )

    use_online_softmax = (
        num_split == 1
        and V.graph.sizevars.size_hint(rnumel) >= config.unroll_reductions_threshold
    )
    if use_online_softmax:
        max_tensor, sum_tensor = OnlineSoftmaxReduction.create(
            input_node=x, num_output=2, reduction_hint=hint, **kwargs
        )
        return max_tensor, sum_tensor

    # Note: [Split online_softmax_reduce]
    # We don't split the reduction for online_softmax_reduce for now.
    # On one hand, supporting split reduction makes things complex since
    # the split-out reductions require 2 inputs rather than one.
    # On the other hand, during training online_softmax_reduce usually
    # doesn't require a split, due to the large batch size
    # (more specifically, batch size times sequence length).
    # We should support split reduction if we find legitimate use cases to
    # motivate the work.
    #
    # TODO: does inference need split online_softmax_reduce?
    warnings.warn(
        textwrap.dedent(
            """
            Online softmax is disabled on the fly since Inductor decides to
            split the reduction. Cut an issue to PyTorch if this is an
            important use case and you want to speed it up with online
            softmax.
            """
        )
    )
    amax = reduce_amax(x, dim, keepdims=True)
    shifted_exp = lowerings[aten.exp](sub(x, amax))
    xsum = sum_(shifted_exp, dim, keepdims=True)
    return amax, xsum
# populate lowerings defined in kernel/*
from . import kernel


import_submodule(kernel)

# Run the registration functions exposed by the sibling lowering modules.
from . import quantized_lowerings


quantized_lowerings.register_quantized_ops()
quantized_lowerings.register_woq_mm_ops()

from . import mkldnn_lowerings


mkldnn_lowerings.register_onednn_fusion_ops()

from . import jagged_lowerings


jagged_lowerings.register_jagged_ops()
@contextlib.contextmanager
def force_fallback(op: torch._ops.OpOverload):
    """
    A context manager to force fallback an op. Used in unit test
    for FallbackKernel.

    While the context is active, ``op`` is lowered with
    ``fallback_handler(op)``.  On exit the previous entry of ``lowerings``
    for ``op`` is restored, or the entry is removed if there was none.

    Args:
        op: the overload to force onto the fallback path; must be an
            ``OpOverload`` so that only a single ``lowerings`` entry has to
            be cleaned up.
    """
    assert isinstance(op, torch._ops.OpOverload), (
        "Only OpOverload to make the clean up easier"
    )
    old_handler = lowerings.get(op)
    try:
        register_lowering(op)(fallback_handler(op))
        yield
    finally:
        # Compare against None explicitly rather than relying on the
        # handler's truthiness.
        if old_handler is not None:
            lowerings[op] = old_handler
        else:
            # Use a default so that a failure before the entry was inserted
            # does not raise KeyError here and mask the original exception.
            lowerings.pop(op, None)
|