_meta_registrations.py 264 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
782788279828082818282828382848285828682878288828982908291829282938294829582968297829882998300830183028303830483058306830783088309831083118312831383148315831683178318831983208321832283238324832583268327832883298330833183328333833483358336833783388339834083418342834383448345834683478348834983508351835283538354835583568357835883598360836183628363836483658366836783688369837083718372837383748375837683778378837983808381838283838384838583868387838883898390839183928393839483958396839783988399840084018402840384048405840684078408840984108411841284138414841584168417841884198420842184228423842484258426842784288429843084318432843384348435843684378438843984408441844284438444844584468447844884498450845184528453845484558456
  1. # mypy: allow-untyped-defs
  2. import math
  3. from collections.abc import Callable, Sequence
  4. from enum import Enum
  5. from functools import wraps
  6. from typing import TypeVar
  7. from typing_extensions import ParamSpec
  8. import torch
  9. import torch._prims_common as utils
  10. from torch import SymBool, SymFloat, Tensor
  11. from torch._decomp import (
  12. _add_op_to_registry,
  13. _convert_out_params,
  14. global_decomposition_table,
  15. meta_table,
  16. )
  17. from torch._ops import OpOverload
  18. from torch._prims import _prim_elementwise_meta, ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND
  19. from torch._prims_common import (
  20. BoolLike,
  21. corresponding_complex_dtype,
  22. corresponding_real_dtype,
  23. elementwise_dtypes,
  24. ELEMENTWISE_TYPE_PROMOTION_KIND,
  25. FloatLike,
  26. IntLike,
  27. make_contiguous_strides_for,
  28. Number,
  29. suggest_memory_format,
  30. TensorLike,
  31. )
  32. from torch._prims_common.wrappers import (
  33. _maybe_convert_to_dtype,
  34. _maybe_resize_out,
  35. _resize_output_check,
  36. _safe_copy_out,
  37. out_wrapper,
  38. )
  39. from torch._refs import _broadcast_shapes, _maybe_broadcast
  40. from torch.fx.experimental import _config as exp_config
  41. from torch.nn.functional import ScalingType, SwizzleType
  42. from torch.utils import _pytree as pytree
# Type variables used to type the decorator returned by `register_meta`.
_T = TypeVar("_T")
_P = ParamSpec("_P")
# Shorthand for the ATen operator namespace used by every registration below.
aten = torch.ops.aten
# Library handle for the "aten" namespace, kind "IMPL", dispatch key "Meta".
# As the name says, do not register through it directly; use `register_meta`.
_meta_lib_dont_use_me_use_register_meta = torch.library.Library("aten", "IMPL", "Meta")
# Integer reduction-mode codes 0/1/2. Presumably consumed by bag/reduction
# meta kernels later in this file; verify against the matching ATen enum.
MODE_SUM, MODE_MEAN, MODE_MAX = range(3)
  48. def ceil_div(a, b):
  49. return (a + b - 1) // b
  50. def round_up(x, y):
  51. """Rounds up x to nearest multiple of y"""
  52. return ((x + y - 1) // y) * y
  53. def register_meta(op) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  54. def wrapper(fn):
  55. fn = _convert_out_params(fn)
  56. def register(op):
  57. _add_op_to_registry(meta_table, op, fn)
  58. pytree.tree_map_(register, op)
  59. return fn
  60. return wrapper
  61. def elementwise_meta(
  62. *args,
  63. type_promotion: ELEMENTWISE_TYPE_PROMOTION_KIND,
  64. ):
  65. # Perform type promotion, as this is expected from prim_metafunction
  66. _, result_dtype = utils.elementwise_dtypes(
  67. *args,
  68. type_promotion_kind=type_promotion,
  69. )
  70. args = [_maybe_convert_to_dtype(x, result_dtype) for x in args]
  71. # Broadcast
  72. args = _maybe_broadcast(*args)
  73. # Perform prim checks
  74. return _prim_elementwise_meta(
  75. *args, type_promotion=ELEMENTWISE_PRIM_TYPE_PROMOTION_KIND.DEFAULT
  76. )
  77. def toRealValueType(dtype):
  78. from_complex = {
  79. torch.complex32: torch.half,
  80. torch.cfloat: torch.float,
  81. torch.cdouble: torch.double,
  82. }
  83. return from_complex.get(dtype, dtype)
  84. def check_inplace_broadcast(self_shape, *args_shape):
  85. broadcasted_shape = tuple(_broadcast_shapes(self_shape, *args_shape))
  86. torch._check(
  87. broadcasted_shape == self_shape,
  88. lambda: f"output with shape {self_shape} doesn't match the broadcast shape {broadcasted_shape}",
  89. )
  90. @register_meta([aten.linspace, aten.logspace])
  91. @out_wrapper()
  92. def meta_linspace_logspace(
  93. start,
  94. end,
  95. steps,
  96. base=None,
  97. dtype=None,
  98. device=None,
  99. layout=torch.strided,
  100. pin_memory=False,
  101. requires_grad=False,
  102. ):
  103. if isinstance(start, torch.Tensor):
  104. torch._check(
  105. start.dim() == 0,
  106. lambda: "linspace only supports 0-dimensional start and end tensors",
  107. )
  108. if isinstance(end, torch.Tensor):
  109. torch._check(
  110. end.dim() == 0,
  111. lambda: "linspace only supports 0-dimensional start and end tensors",
  112. )
  113. if any(isinstance(arg, complex) for arg in (start, end, steps)):
  114. default_complex_dtype = utils.corresponding_complex_dtype(
  115. torch.get_default_dtype()
  116. )
  117. if dtype is None:
  118. dtype = default_complex_dtype
  119. else:
  120. torch._check(
  121. utils.is_complex_dtype(dtype),
  122. lambda: f"linspace(): inferred dtype {default_complex_dtype} can't be safely cast to passed dtype {dtype}",
  123. )
  124. else:
  125. dtype = dtype or torch.get_default_dtype()
  126. assert isinstance(dtype, torch.dtype)
  127. # steps does not participate in the computation of the dtype
  128. torch._check_type(
  129. isinstance(steps, IntLike),
  130. lambda: f"received an invalid combination of arguments - got \
  131. ({type(start).__name__}, {type(end).__name__}, {type(steps).__name__})",
  132. )
  133. assert isinstance(steps, IntLike) # for mypy
  134. torch._check(steps >= 0, lambda: "number of steps must be non-negative")
  135. return torch.empty(
  136. (steps,), # type: ignore[arg-type]
  137. dtype=dtype,
  138. layout=layout,
  139. device="meta",
  140. pin_memory=pin_memory,
  141. requires_grad=requires_grad,
  142. )
  143. @register_meta([aten.take.default, aten.take.out])
  144. @out_wrapper()
  145. def meta_take(self, index):
  146. # Type and device checks
  147. torch._check(
  148. index.dtype == torch.long,
  149. lambda: f"take(): Expected a long tensor for index, but got {index.dtype}",
  150. )
  151. # Index checks
  152. torch._check_index(
  153. not (self.numel() == 0 and index.numel() != 0),
  154. lambda: "take(): tried to take from an empty tensor",
  155. )
  156. return self.new_empty(index.shape)
  157. @register_meta([aten.linalg_cross.default, aten.linalg_cross.out])
  158. @out_wrapper()
  159. def linalg_cross(self, other, *, dim=-1):
  160. x_d = self.ndim
  161. y_d = other.ndim
  162. torch._check(
  163. x_d == y_d,
  164. lambda: "linalg.cross: inputs must have the same number of dimensions.",
  165. )
  166. torch._check(
  167. self.size(dim) == 3 and other.size(dim) == 3,
  168. lambda: (
  169. f"linalg.cross: inputs dimension {dim} must have length 3. "
  170. f"Got {self.size(dim)} and {other.size(dim)}"
  171. ),
  172. )
  173. out_shape = _broadcast_shapes(self.shape, other.shape)
  174. return self.new_empty(out_shape)
  175. @register_meta(aten.linalg_matrix_exp)
  176. @out_wrapper()
  177. def linalg_matrix_exp(self):
  178. squareCheckInputs(self, "linalg.matrix_exp")
  179. checkFloatingOrComplex(self, "linalg.matrix_exp")
  180. return torch.empty_like(self, memory_format=torch.contiguous_format)
  181. @register_meta(
  182. [aten.cummax.default, aten.cummax.out, aten.cummin.default, aten.cummin.out]
  183. )
  184. @out_wrapper("values", "indices")
  185. def cummaxmin(self, dim):
  186. values = torch.empty(self.shape, device=self.device, dtype=self.dtype)
  187. indices = torch.empty(self.shape, device=self.device, dtype=torch.int64)
  188. if self.numel() != 0 and self.ndim != 0:
  189. # Checks that dim is within bounds
  190. maybe_wrap_dim(dim, self.ndim)
  191. return values, indices
  192. @register_meta([aten.logcumsumexp.default, aten.logcumsumexp.out])
  193. @out_wrapper()
  194. def logcumsumexp(self, dim):
  195. # Checks that dim is within bounds
  196. maybe_wrap_dim(dim, self.ndim)
  197. return torch.empty_like(self, memory_format=torch.contiguous_format)
  198. # Stride-related code from _exec_fft in aten/src/ATen/native/mkl/SpectralOps.cpp
  199. # and aten/src/ATen/cuda/SpectralOps.cpp
  200. #
  201. # Although the actual FFT launch is different, all the permuting code appears
  202. # to be the same
  203. def _exec_fft(out, self, out_sizes, dim, *, forward):
  204. ndim = self.ndim
  205. signal_ndim = len(dim)
  206. batch_dims = ndim - signal_ndim
  207. # Permute dimensions so batch dimensions come first, and in stride order
  208. dim_permute = list(range(ndim))
  209. is_transformed_dim = [False for _ in range(ndim)]
  210. for d in dim:
  211. is_transformed_dim[d] = True
  212. # std::partition
  213. left, right = [], []
  214. for d in dim_permute:
  215. if not is_transformed_dim[d]:
  216. left.append(d)
  217. else:
  218. right.append(d)
  219. dim_permute = left + right
  220. batch_end = len(left)
  221. self_strides = self.stride()
  222. tmp = dim_permute[:batch_end]
  223. tmp.sort(key=lambda x: self_strides[x], reverse=True)
  224. dim_permute = tmp + dim_permute[batch_end:]
  225. input = self.permute(dim_permute)
  226. # Collapse batch dimensions into a single dimension
  227. batched_sizes = [-1] + list(input.shape[batch_dims:])
  228. input = input.reshape(batched_sizes)
  229. batch_size = input.size(0)
  230. batched_sizes[0] = batch_size
  231. batched_out_sizes = list(batched_sizes)
  232. for i in range(len(dim)):
  233. batched_out_sizes[i + 1] = out_sizes[dim[i]]
  234. out.resize_(batched_out_sizes, memory_format=torch.contiguous_format)
  235. # Inplace reshaping to original batch shape and inverting the dimension permutation
  236. out_strides = [0 for _ in range(ndim)]
  237. batch_numel = 1
  238. i = batch_dims - 1
  239. while i >= 0:
  240. out_strides[dim_permute[i]] = batch_numel * out.stride(0)
  241. batch_numel *= out_sizes[dim_permute[i]]
  242. i -= 1
  243. for i in range(batch_dims, ndim):
  244. out_strides[dim_permute[i]] = out.stride(1 + (i - batch_dims))
  245. out.as_strided_(out_sizes, out_strides, out.storage_offset())
  246. return out
  247. def _sort_dims(self: Tensor, dim: list[int], exclude_last: bool = False):
  248. sorted_dims = list(dim)
  249. self_strides = self.stride()
  250. sorted_dims[: len(sorted_dims) - int(exclude_last)].sort(
  251. key=lambda i: self_strides[i]
  252. )
  253. return sorted_dims
  254. # See _fft_c2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
  255. # and _fft_c2c_mkl in aten/src/ATen/native/mkl/SpectralOps.cpp
  256. @register_meta([aten._fft_c2c.default, aten._fft_c2c.out])
  257. @out_wrapper()
  258. def meta_fft_c2c(self, dim, normalization, forward):
  259. torch._check(self.dtype.is_complex)
  260. if not dim:
  261. return self.clone()
  262. sorted_dims = _sort_dims(self, dim)
  263. out = self.new_empty(self.size())
  264. return _exec_fft(out, self, self.size(), sorted_dims, forward=forward)
  265. cufft_max_ndim = 3
  266. def use_optimized_cufft_path(dim: list[int]):
  267. if len(dim) > cufft_max_ndim or (len(dim) >= 2 and dim[0] == 0 and dim[1] == 1):
  268. return False
  269. else:
  270. return True
@register_meta([aten._fft_r2c.default, aten._fft_r2c.out])
@out_wrapper()
def meta_fft_r2c(self, dim, normalization, onesided):
    """Meta kernel for real-to-complex FFT over ``dim``.

    The output dtype is the complex counterpart of ``self.dtype``. When
    ``onesided`` is true, the last transformed dim shrinks to ``n // 2 + 1``.

    On CUDA/XPU, the output strides are derived by replaying the backend's
    sequence of ``_exec_fft`` calls:
    - one R2C pass over the last dim;
    - then C2C passes over the remaining dims, at most ``cufft_max_ndim`` at
      a time, ping-ponging between two buffers.

    On other devices a plain contiguous tensor of the output size is
    returned. ``normalization`` does not affect the result here.
    """
    torch._check(self.dtype.is_floating_point)
    input_sizes = list(self.size())
    out_sizes = list(input_sizes)
    last_dim = dim[-1]
    # Length of the last transformed dim in a one-sided transform.
    last_dim_halfsize = input_sizes[last_dim] // 2 + 1
    onesided_sizes = list(input_sizes)
    onesided_sizes[last_dim] = last_dim_halfsize
    if onesided:
        out_sizes[last_dim] = last_dim_halfsize
    if device_hint(self) == "cuda" or device_hint(self) == "xpu":
        # _fft_r2c_cufft in aten/src/ATen/native/cuda/SpectralOps.cpp
        # _fft_r2c_xpu in torch-xpu-ops/src/ATen/native/xpu/SpectralOps.cpp
        output = self.new_empty(
            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
        )
        working_tensor = self
        if device_hint(self) == "cuda" and use_optimized_cufft_path(dim):
            # Single call covering all transformed dims.
            _exec_fft(output, working_tensor, out_sizes, dim, forward=True)
        else:
            # First do the R2C transform on the last dimension
            target_sizes = out_sizes if len(dim) == 1 else onesided_sizes
            _exec_fft(output, working_tensor, target_sizes, [last_dim], forward=True)
            if len(dim) > 1:
                # Second buffer for the ping-pong C2C passes below.
                working_tensor = self.new_empty(
                    out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
                )
            # Then any remaining C2C transforms
            sorted_dims = dim[:-1]
            while sorted_dims:
                # Previous output becomes this pass's input.
                output, working_tensor = working_tensor, output
                strides = working_tensor.stride()
                # Re-sort each pass because _exec_fft re-strides its output;
                # order is by decreasing stride.
                sorted_dims.sort(
                    key=lambda i: strides[i], reverse=True
                )  # NB reverse! Not sure if this is og bug
                # Transform the (up to cufft_max_ndim) smallest-stride dims.
                max_dims = min(cufft_max_ndim, len(sorted_dims))
                last_dims = sorted_dims[len(sorted_dims) - max_dims :]
                _exec_fft(
                    output, working_tensor, onesided_sizes, last_dims, forward=True
                )
                sorted_dims = sorted_dims[: len(sorted_dims) - max_dims]
        if not onesided:
            # The passes above produce the one-sided size. If the full size is
            # requested, the result is a contiguous resize of the spare
            # buffer. NOTE(review): presumably the backend then fills the
            # other half; only the shape matters here.
            if output.size(last_dim) != out_sizes[last_dim]:
                working_tensor.resize_(out_sizes, memory_format=torch.contiguous_format)
                output = working_tensor
        return output
    else:
        return self.new_empty(
            out_sizes, dtype=utils.corresponding_complex_dtype(self.dtype)
        )
  323. @register_meta(aten.randperm.generator_out)
  324. def meta_randperm(n, *, generator=None, out):
  325. return _maybe_resize_out(out, torch.Size([n]))
  326. @register_meta(aten.randperm.default)
  327. def meta_randperm_default(
  328. n,
  329. *,
  330. dtype=torch.long,
  331. layout=None,
  332. device=None,
  333. pin_memory=None,
  334. ):
  335. return torch.empty(
  336. n, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  337. )
  338. @register_meta([aten.randint.default, aten.randint.out])
  339. @out_wrapper()
  340. def meta_randint(
  341. high,
  342. size,
  343. *,
  344. dtype=torch.long,
  345. layout=None,
  346. device=None,
  347. pin_memory=None,
  348. ):
  349. low = 0
  350. torch._check(
  351. high > low,
  352. lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}",
  353. )
  354. return torch.empty(
  355. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  356. )
  357. @register_meta([aten.randint.low, aten.randint.low_out])
  358. @out_wrapper()
  359. def meta_randint_low(
  360. low,
  361. high,
  362. size,
  363. *,
  364. dtype=torch.long,
  365. layout=None,
  366. device=None,
  367. pin_memory=None,
  368. ):
  369. torch._check(
  370. high > low,
  371. lambda: f"random_ expects 'from' to be less than 'to', but got from={low} >= to={high}",
  372. )
  373. return torch.empty(
  374. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  375. )
  376. @register_meta([aten.rand.default, aten.rand.out])
  377. @out_wrapper()
  378. def meta_rand_default(size, *, dtype=None, layout=None, device=None, pin_memory=None):
  379. return torch.empty(
  380. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  381. )
  382. @register_meta([aten._fft_c2r.default, aten._fft_c2r.out])
  383. @out_wrapper()
  384. def meta_fft_c2r(self: Tensor, dim: list[int], normalization: int, lastdim: int):
  385. # _fft_c2r_mkl
  386. torch._check(self.dtype.is_complex)
  387. if device_hint(self) == "cuda":
  388. out_sizes = list(self.size())
  389. out_sizes[dim[-1]] = lastdim
  390. output = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
  391. if use_optimized_cufft_path(dim):
  392. return _exec_fft(
  393. output,
  394. self.clone(memory_format=torch.contiguous_format),
  395. out_sizes,
  396. dim,
  397. forward=False,
  398. )
  399. else:
  400. # First complete any C2C transforms
  401. if len(dim) > 1:
  402. temp = meta_fft_c2c(self, dim[:-1], 0, lastdim) # fft_norm_mode::none
  403. else:
  404. temp = self.clone(memory_format=torch.contiguous_format)
  405. return _exec_fft(output, temp, out_sizes, [dim[-1]], forward=False)
  406. else:
  407. input = self
  408. if len(dim) > 1:
  409. c2c_dims = dim[:-1]
  410. input = meta_fft_c2c(self, c2c_dims, normalization, forward=False)
  411. dim = dim[-1:]
  412. out_sizes = list(input.size())
  413. out_sizes[dim[-1]] = lastdim
  414. out = self.new_empty(out_sizes, dtype=toRealValueType(self.dtype))
  415. return _exec_fft(out, input, out_sizes, dim, forward=False)
  416. @register_meta(aten.copy_.default)
  417. def meta_copy_(self, src, non_blocking=False):
  418. # This code simulates the original decomp from inductor,
  419. # which runs most of the meta checks that we care about.
  420. # In theory, we should make this more robust by carefully
  421. # auditing our C++ copy_() kernel and copying the checks here.
  422. from torch.fx.experimental.symbolic_shapes import free_unbacked_symbols
  423. # TODO: Ideally, we'd insert a deferred runtime assert here, but if we are
  424. # calling an actual copy_, you'll get that automatically
  425. # https://github.com/pytorch/pytorch/issues/122477
  426. if (
  427. not free_unbacked_symbols(self) and torch._debug_has_internal_overlap(self) == 1
  428. ): # 1 == MemOverlap::Yes
  429. raise RuntimeError(
  430. "more than one element of the written-to tensor refers to a single memory location"
  431. )
  432. if isinstance(src, Tensor):
  433. intermediate = src.to(self, non_blocking)
  434. if self.size() != intermediate.size():
  435. aten.expand_copy.default(intermediate, self.size())
  436. return self
  437. def inferUnsqueezeGeometry(tensor, dim):
  438. result_sizes = list(tensor.size())
  439. result_strides = list(tensor.stride())
  440. # pyrefly: ignore [unsupported-operation]
  441. new_stride = 1 if dim >= tensor.dim() else result_sizes[dim] * result_strides[dim]
  442. # pyrefly: ignore [bad-argument-type]
  443. result_sizes.insert(dim, 1)
  444. # pyrefly: ignore [bad-argument-type]
  445. result_strides.insert(dim, new_stride)
  446. return result_sizes, result_strides
  447. @register_meta(aten.unsqueeze_.default)
  448. def meta_unsqueeze_(self, dim):
  449. dim = maybe_wrap_dim(dim, self.dim() + 1)
  450. g_sizes, g_strides = inferUnsqueezeGeometry(self, dim)
  451. self.as_strided_(g_sizes, g_strides)
  452. return self
  453. @register_meta(aten._sparse_semi_structured_linear)
  454. def meta_sparse_structured_linear(
  455. input: Tensor,
  456. weight: Tensor,
  457. _meta: Tensor,
  458. bias: Tensor | None = None,
  459. _activation_opt: str | None = None,
  460. out_dtype: torch.dtype | None = None,
  461. ):
  462. output_sizes = list(input.shape)
  463. if bias is not None:
  464. assert weight.size(0) == bias.size(0), "output size mismatch"
  465. assert weight.size(1) == input.size(-1) / 2
  466. output_sizes[-1] = weight.size(0)
  467. # see: https://github.com/pytorch/pytorch/pull/114477#issuecomment-1830121375
  468. # We assume that we have already squashed the inputs into a 2-D tensor
  469. # Then, as the output is transposed, we need to propagate the transposed
  470. # stride information to the output tensor
  471. assert len(input.shape) == 2, "we can only handle the squashed input case"
  472. transposed_strides = (1, input.size(0))
  473. if out_dtype is not None:
  474. assert input.dtype == torch.int8 and out_dtype == torch.int32, (
  475. "out_dtype is only supported for i8i8->i32 linear operator"
  476. )
  477. output = input.new_empty(
  478. output_sizes,
  479. dtype=input.dtype if out_dtype is None else out_dtype,
  480. ).as_strided(output_sizes, transposed_strides)
  481. return output
  482. @register_meta(aten._sparse_semi_structured_mm)
  483. def meta_sparse_structured_mm(
  484. mat1: Tensor,
  485. mat1_meta: Tensor,
  486. mat2: Tensor,
  487. out_dtype: torch.dtype | None = None,
  488. ):
  489. assert len(mat1.shape) == 2
  490. assert len(mat1_meta.shape) == 2
  491. assert len(mat2.shape) == 2
  492. assert mat1.size(1) == mat2.size(0) / 2
  493. output_sizes = [mat1.size(0), mat2.size(1)]
  494. if out_dtype is not None:
  495. assert mat2.dtype == torch.int8 and out_dtype == torch.int32, (
  496. "out_dtype is only supported for i8i8->i32 linear operator"
  497. )
  498. output = mat2.new_empty(
  499. output_sizes,
  500. dtype=mat2.dtype if out_dtype is None else out_dtype,
  501. )
  502. return output
  503. @register_meta(aten._sparse_semi_structured_addmm)
  504. def meta_sparse_structured_addmm(
  505. input: Tensor,
  506. mat1: Tensor,
  507. mat1_meta: Tensor,
  508. mat2: Tensor,
  509. *,
  510. alpha=1,
  511. beta=1,
  512. out_dtype: torch.dtype | None = None,
  513. ):
  514. assert len(input.shape) == 1, (
  515. "only input broadcasted to columns of mat1 * mat2 product is supported"
  516. )
  517. assert len(mat1.shape) == 2
  518. assert len(mat1_meta.shape) == 2
  519. assert len(mat2.shape) == 2
  520. assert input.size(0) == mat1.size(0), (
  521. "only input broadcasted to columns of mat1 * mat2 product is supported"
  522. )
  523. assert mat1.size(1) == mat2.size(0) / 2
  524. output_sizes = [mat1.size(0), mat2.size(1)]
  525. if out_dtype is not None:
  526. assert mat2.dtype == torch.int8 and out_dtype == torch.int32, (
  527. "out_dtype is only supported for i8i8->i32 linear operator"
  528. )
  529. output = mat2.new_empty(
  530. output_sizes,
  531. dtype=mat2.dtype if out_dtype is None else out_dtype,
  532. )
  533. return output
  534. @register_meta(aten._cslt_sparse_mm)
  535. def meta__cslt_sparse_mm(
  536. compressed_A: torch.Tensor,
  537. dense_B: torch.Tensor,
  538. bias: Tensor | None = None,
  539. alpha: Tensor | None = None,
  540. out_dtype: torch.dtype | None = None,
  541. transpose_result: bool = False,
  542. alg_id: int = 0,
  543. split_k: int = 1,
  544. split_k_mode: int = -1,
  545. ):
  546. assert dense_B.dtype in {
  547. torch.float32,
  548. torch.float16,
  549. torch.bfloat16,
  550. torch.int8,
  551. torch.float8_e4m3fn,
  552. }, "_cslt_sparse_mm only supports fp16, bf16, int8, and fp8e4m3"
  553. assert compressed_A.dtype == dense_B.dtype, "inputs must have the same dtype"
  554. assert len(dense_B.shape) == 2, "_cslt_sparse_mm only supports 2d inputs"
  555. is_8bit_input_type = compressed_A.dtype in [torch.int8, torch.float8_e4m3fn]
  556. if is_8bit_input_type:
  557. assert not dense_B.is_contiguous(), (
  558. "dense input must be transposed for 8bit dtypes"
  559. )
  560. n = dense_B.size(1)
  561. m = compressed_A.size(0)
  562. if bias is not None:
  563. assert m == bias.size(0)
  564. if out_dtype is not None:
  565. assert is_8bit_input_type and out_dtype in {
  566. torch.float16,
  567. torch.bfloat16,
  568. torch.int32,
  569. torch.float8_e4m3fn,
  570. }, (
  571. f"out_dtype is not supported for {compressed_A.dtype} x {dense_B.dtype} -> {out_dtype} matmul!"
  572. )
  573. output_shape = (n, m) if transpose_result else (m, n)
  574. return dense_B.new_empty(output_shape, dtype=out_dtype)
  575. @register_meta(aten.index_reduce.default)
  576. def meta_index_reduce(
  577. self: Tensor,
  578. dim: int,
  579. index: Tensor,
  580. source: torch.Tensor,
  581. reduce: str,
  582. *,
  583. include_self: bool = True,
  584. ) -> Tensor:
  585. return torch.empty_like(self, memory_format=torch.contiguous_format)
  586. @register_meta(aten.index_reduce_.default)
  587. def meta_index_reduce_(
  588. self: Tensor,
  589. dim: int,
  590. index: Tensor,
  591. source: torch.Tensor,
  592. reduce: str,
  593. *,
  594. include_self: bool = True,
  595. ) -> Tensor:
  596. return self
  597. # Implementations below are taken from https://github.com/albanD/subclass_zoo/blob/main/python_meta_tensor.py
  598. @out_wrapper()
  599. @register_meta(aten.index_select.default)
  600. def meta_index_select(self, dim, index):
  601. result_size = list(self.size())
  602. if self.dim() > 0:
  603. result_size[dim] = index.numel()
  604. return self.new_empty(result_size)
  605. @register_meta(aten.segment_reduce.default)
  606. def meta_segment_reduce(
  607. data: Tensor,
  608. reduce: str,
  609. *,
  610. lengths: Tensor | None = None,
  611. indices: Tensor | None = None,
  612. offsets: Tensor | None = None,
  613. axis: int = 0,
  614. unsafe: bool = False,
  615. initial=None,
  616. ) -> Tensor:
  617. if indices is not None:
  618. raise NotImplementedError(
  619. "segment_reduce(): indices based reduction is not supported yet."
  620. )
  621. def segment_reduce_lengths_tensor(lengths_shape):
  622. return torch.empty(
  623. lengths_shape + data.shape[axis + 1 :],
  624. dtype=data.dtype,
  625. device="meta",
  626. memory_format=torch.contiguous_format,
  627. )
  628. if lengths is not None:
  629. return segment_reduce_lengths_tensor(lengths.shape)
  630. # FIXME should probably check that lengths and offset aren't both set, but
  631. # the ATen implementation neglects this too
  632. if offsets is not None:
  633. # lengths == torch.diff(offsets)
  634. lengths_shape = offsets.shape[:-1] + (offsets.shape[-1] - 1,)
  635. return segment_reduce_lengths_tensor(lengths_shape)
  636. raise RuntimeError("segment_reduce(): Either lengths or offsets must be defined.")
  637. @register_meta([aten.max.default, aten.max.unary_out])
  638. @out_wrapper()
  639. def meta_max(self):
  640. return self.new_empty(())
  641. @register_meta(aten.max.dim)
  642. def meta_max_dim(self, dim, keepdim=False):
  643. dim = utils.reduction_dims(self.shape, (dim,))
  644. output_shape = _compute_reduction_shape(self, dim, keepdim)
  645. return (
  646. self.new_empty(output_shape),
  647. self.new_empty(output_shape, dtype=torch.long),
  648. )
  649. @register_meta([aten.min.default, aten.min.unary_out])
  650. @out_wrapper()
  651. def meta_min(self):
  652. return self.new_empty(())
  653. @register_meta(aten.min.dim)
  654. def meta_min_dim(self, dim, keepdim=False):
  655. dim = utils.reduction_dims(self.shape, (dim,))
  656. output_shape = _compute_reduction_shape(self, dim, keepdim)
  657. return (
  658. self.new_empty(output_shape),
  659. self.new_empty(output_shape, dtype=torch.long),
  660. )
  661. @register_meta(aten.angle.default)
  662. def meta_angle(self):
  663. if self.is_complex():
  664. result_dtype = corresponding_real_dtype(self.dtype)
  665. else:
  666. _, result_dtype = elementwise_dtypes(
  667. self,
  668. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  669. )
  670. return torch.empty_like(self, dtype=result_dtype)
  671. @register_meta(aten.angle.out)
  672. def meta_angle_out(self, out):
  673. torch._resize_output_(out, self.size(), self.device)
  674. return out.copy_(torch.angle(self))
  675. @register_meta(aten._assert_async.default)
  676. def assert_async(val):
  677. return
  678. @register_meta(aten._assert_async.msg)
  679. def assert_async_meta(val, assert_msg):
  680. return
  681. @register_meta(aten._print.default)
  682. def print_meta(s):
  683. return
  684. @register_meta(aten._make_dep_token.default)
  685. def make_dep_token(
  686. *,
  687. dtype=None,
  688. layout=None,
  689. device=None,
  690. pin_memory=None,
  691. memory_format=None,
  692. ):
  693. return torch.empty(0, device="meta")
  694. @register_meta(aten.sym_constrain_range.default)
  695. def sym_constrain_range(size, min=None, max=None):
  696. # Avoid importing sympy at a module level
  697. from torch.fx.experimental.symbolic_shapes import constrain_range
  698. if isinstance(size, (SymFloat, SymBool)):
  699. raise ValueError("Constraining SymFloat or Symbool is nyi")
  700. constrain_range(size, min=min, max=max)
  701. @register_meta(aten._functional_sym_constrain_range.default)
  702. def functional_sym_constrain_range(size, min=None, max=None, dep_token=None):
  703. aten.sym_constrain_range(size, min=min, max=max)
  704. return dep_token
  705. @register_meta(aten.sym_constrain_range_for_size.default)
  706. def sym_constrain_range_for_size(size, min=None, max=None):
  707. # Avoid importing sympy at a module level
  708. from torch.fx.experimental.symbolic_shapes import _constrain_range_for_size
  709. if min is None and max is None:
  710. torch._check(size >= 0)
  711. return
  712. if isinstance(size, (SymFloat, SymBool)):
  713. raise ValueError("Constraining SymFloat or Symbool is nyi")
  714. if type(size) is int:
  715. if min is not None:
  716. torch._check(size >= min)
  717. if max is not None:
  718. torch._check(size <= max)
  719. return
  720. _constrain_range_for_size(size, min=min, max=max)
  721. @register_meta(aten._functional_sym_constrain_range_for_size.default)
  722. def functional_sym_constrain_range_for_size(size, min, max, dep_token):
  723. aten.sym_constrain_range_for_size(size, min=min, max=max)
  724. return dep_token
  725. @register_meta(aten._functional_assert_async.msg)
  726. def functional_assert_async_meta(val, assert_msg, dep_token):
  727. return dep_token
  728. # From aten/src/ATen/native/LinearAlgebraUtils.h
  729. def squareCheckInputs(self: Tensor, f_name: str):
  730. assert self.dim() >= 2, (
  731. f"{f_name}: The input tensor must have at least 2 dimensions."
  732. )
  733. assert self.size(-1) == self.size(-2), (
  734. f"{f_name}: A must be batches of square matrices, but they are {self.size(-2)} by {self.size(-1)} matrices"
  735. )
  736. # Validates input shapes and devices
  737. # for linear solve methods (solve, cholesky_solve, lu_solve, triangular_solve)
  738. # From aten/src/ATen/native/LinearAlgebraUtils.h
  739. def linearSolveCheckInputs(self: Tensor, A: Tensor, name: str):
  740. torch._check(
  741. self.device == A.device,
  742. lambda: (
  743. f"Expected b and A to be on the same device, but found b on "
  744. f"{self.device} and A on {A.device} instead."
  745. ),
  746. )
  747. torch._check(
  748. self.dtype == A.dtype,
  749. lambda: (
  750. f"Expected b and A to have the same dtype, but found b of type "
  751. f"{self.dtype} and A of type {A.dtype} instead."
  752. ),
  753. )
  754. torch._check(
  755. A.size(-1) == A.size(-2),
  756. lambda: (
  757. f"A must be batches of square matrices, "
  758. f"but they are {A.size(-2)} by {A.size(-1)} matrices"
  759. ),
  760. )
  761. torch._check(
  762. A.size(-1) == self.size(-2),
  763. lambda: (
  764. f"Incompatible matrix sizes for {name}: each A "
  765. f"matrix is {A.size(-1)} by {A.size(-1)}"
  766. f" but each b matrix is {self.size(-2)} by {self.size(-1)}"
  767. ),
  768. )
  769. # From aten/src/ATen/native/LinearAlgebraUtils.h
  770. def checkFloatingOrComplex(
  771. t: Tensor,
  772. f_name: str,
  773. allow_low_precision_dtypes: bool = True,
  774. ):
  775. dtype = t.dtype
  776. torch._check(
  777. t.is_floating_point() or t.is_complex(),
  778. lambda: f"{f_name}: Expected a floating point or complex tensor as input. Got {dtype}",
  779. )
  780. if not allow_low_precision_dtypes:
  781. torch._check(
  782. dtype in (torch.float, torch.double, torch.cfloat, torch.cdouble),
  783. lambda: f"{f_name}: Low precision dtypes not supported. Got {dtype}",
  784. )
  785. # From aten/src/ATen/native/LinearAlgebraUtils.h
  786. def checkIsMatrix(A: Tensor, f_name: str, arg_name: str = "A"):
  787. torch._check(
  788. A.dim() >= 2,
  789. lambda: f"{f_name}: The input tensor {arg_name} must have at least 2 dimensions.",
  790. )
  791. def checkInputsSolver(A: Tensor, B: Tensor, left: bool, f_name: str):
  792. squareCheckInputs(A, f_name)
  793. checkIsMatrix(B, f_name)
  794. torch._check(
  795. A.size(-2) == B.size(-2) if left else A.size(-1) == B.size(-1),
  796. lambda: (
  797. f"{f_name}: Incompatible shapes of A and B for the equation "
  798. f"{'AX = B' if left else 'XA = B'}"
  799. f" ({A.size(-2)}x{A.size(-1)} and {B.size(-2)}x{B.size(-1)})"
  800. ),
  801. )
  802. def checkSameDevice(
  803. fn_name: str,
  804. result: Tensor,
  805. input: Tensor,
  806. result_name: str = "result",
  807. ):
  808. torch._check(
  809. result.device == input.device,
  810. lambda: (
  811. f"{fn_name}: Expected {result_name} and input tensors to be on the same device, but got "
  812. f"{result_name} on {result.device} and input on {input.device}"
  813. ),
  814. )
  815. def checkUplo(UPLO: str):
  816. UPLO_uppercase = UPLO.upper()
  817. torch._check(
  818. len(UPLO) == 1 and (UPLO_uppercase == "U" or UPLO_uppercase == "L"),
  819. lambda: f"Expected UPLO argument to be 'L' or 'U', but got {UPLO}",
  820. )
  821. @register_meta([aten._linalg_eigh.default, aten._linalg_eigh.eigenvalues])
  822. @out_wrapper("eigenvalues", "eigenvectors")
  823. def meta__linalg_eigh(A: Tensor, UPLO: str = "L", compute_v: bool = True):
  824. squareCheckInputs(A, "linalg.eigh")
  825. checkUplo(UPLO)
  826. shape = list(A.shape)
  827. if compute_v:
  828. vecs = A.new_empty(shape)
  829. vecs.as_strided_(shape, make_contiguous_strides_for(shape, row_major=False))
  830. else:
  831. vecs = A.new_empty([0])
  832. shape.pop()
  833. vals = A.new_empty(shape, dtype=toRealValueType(A.dtype))
  834. return vals, vecs
  835. @register_meta([aten._linalg_eigvals.default, aten.linalg_eigvals.out])
  836. @out_wrapper()
  837. def meta__linalg_eigvals(input: Tensor) -> Tensor:
  838. squareCheckInputs(input, "linalg.eigvals")
  839. complex_dtype = (
  840. input.dtype
  841. if utils.is_complex_dtype(input.dtype)
  842. else utils.corresponding_complex_dtype(input.dtype)
  843. )
  844. return input.new_empty(input.shape[:-1], dtype=complex_dtype)
  845. @register_meta([aten.linalg_eig])
  846. @out_wrapper("eigenvalues", "eigenvectors")
  847. def meta_linalg_eig(input: Tensor):
  848. squareCheckInputs(input, "linalg.eig")
  849. complex_dtype = (
  850. input.dtype
  851. if utils.is_complex_dtype(input.dtype)
  852. else utils.corresponding_complex_dtype(input.dtype)
  853. )
  854. values = input.new_empty(input.shape[:-1], dtype=complex_dtype)
  855. vectors = input.new_empty(input.shape, dtype=complex_dtype)
  856. is_cuda = device_hint(input) == "cuda"
  857. vectors.as_strided_(
  858. input.shape, make_contiguous_strides_for(input.shape, row_major=is_cuda)
  859. )
  860. return values, vectors
  861. def cloneBatchedColumnMajor(src: Tensor) -> Tensor:
  862. return src.mT.clone(memory_format=torch.contiguous_format).transpose(-2, -1)
  863. @register_meta(aten._cholesky_solve_helper)
  864. @out_wrapper()
  865. def _cholesky_solve_helper(self: Tensor, A: Tensor, upper: bool) -> Tensor:
  866. return cloneBatchedColumnMajor(self)
  867. @register_meta(aten.cholesky_solve)
  868. @out_wrapper()
  869. def cholesky_solve(self: Tensor, A: Tensor, upper: bool = False) -> Tensor:
  870. torch._check(
  871. self.ndim >= 2,
  872. lambda: f"b should have at least 2 dimensions, but has {self.ndim} dimensions instead",
  873. )
  874. torch._check(
  875. A.ndim >= 2,
  876. lambda: f"u should have at least 2 dimensions, but has {A.ndim} dimensions instead",
  877. )
  878. self_broadcasted, A_broadcasted = _linalg_broadcast_batch_dims_name(
  879. self, A, "cholesky_solve"
  880. )
  881. return _cholesky_solve_helper(self_broadcasted, A_broadcasted, upper)
  882. @register_meta(aten.cholesky)
  883. @out_wrapper()
  884. def cholesky(self: Tensor, upper: bool = False) -> Tensor:
  885. if self.numel() == 0:
  886. return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
  887. squareCheckInputs(self, "cholesky")
  888. return cloneBatchedColumnMajor(self)
  889. @register_meta(aten.cholesky_inverse)
  890. @out_wrapper()
  891. def cholesky_inverse(self: Tensor, upper: bool = False) -> Tensor:
  892. squareCheckInputs(self, "cholesky_inverse")
  893. return cloneBatchedColumnMajor(self)
  894. # From aten/src/ATen/native/BatchLinearAlgebra.cpp
  895. @register_meta(aten.linalg_cholesky_ex.default)
  896. def linalg_cholesky_ex(A: Tensor, upper: bool = False, check_errors: bool = False):
  897. squareCheckInputs(A, "linalg.cholesky")
  898. checkFloatingOrComplex(A, "linalg.cholesky")
  899. A_shape = A.shape
  900. ndim = len(A_shape)
  901. # L
  902. L_strides = make_contiguous_strides_for(A_shape, False)
  903. L = A.new_empty(A_shape)
  904. L.as_strided_(A_shape, L_strides)
  905. # infos
  906. infos = A.new_empty(A_shape[0 : ndim - 2], dtype=torch.int32)
  907. return L, infos
  908. @register_meta(
  909. [aten.linalg_householder_product.default, aten.linalg_householder_product.out]
  910. )
  911. @out_wrapper()
  912. def linalg_householder_product(input: Tensor, tau: Tensor) -> Tensor:
  913. torch._check(
  914. input.ndim >= 2,
  915. lambda: "torch.linalg.householder_product: input must have at least 2 dimensions.",
  916. )
  917. torch._check(
  918. input.size(-2) >= input.size(-1),
  919. lambda: "torch.linalg.householder_product: input.shape[-2] must be greater than or equal to input.shape[-1]",
  920. )
  921. torch._check(
  922. input.size(-1) >= tau.size(-1),
  923. lambda: "torch.linalg.householder_product: input.shape[-1] must be greater than or equal to tau.shape[-1]",
  924. )
  925. torch._check(
  926. input.ndim - tau.ndim == 1,
  927. lambda: (
  928. f"torch.linalg.householder_product: Expected tau to have one dimension less than input, "
  929. f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
  930. ),
  931. )
  932. if input.ndim > 2:
  933. expected_batch_tau_shape = input.shape[:-2]
  934. actual_batch_tau_shape = tau.shape[:-1]
  935. torch._check(
  936. actual_batch_tau_shape == expected_batch_tau_shape,
  937. lambda: (
  938. f"torch.linalg.householder_product: Expected batch dimensions of tau to be "
  939. f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
  940. ),
  941. )
  942. torch._check(
  943. tau.dtype == input.dtype,
  944. lambda: (
  945. f"torch.linalg.householder_product: tau dtype {tau.dtype}"
  946. f" does not match input dtype {input.dtype}"
  947. ),
  948. )
  949. checkSameDevice("torch.linalg.householder_product", tau, input, "tau")
  950. return torch.empty_strided(
  951. size=input.shape,
  952. stride=make_contiguous_strides_for(input.shape, row_major=False),
  953. dtype=input.dtype,
  954. device=input.device,
  955. )
  956. # From aten/src/ATen/native/BatchLinearAlgebra.cpp
  957. @register_meta(aten.linalg_inv_ex.default)
  958. def linalg_inv_ex_meta(A: Tensor, check_errors: bool = False):
  959. squareCheckInputs(A, "linalg.inv_ex")
  960. checkFloatingOrComplex(A, "linalg.inv_ex", allow_low_precision_dtypes=False)
  961. L = A.new_empty(A.shape)
  962. L.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
  963. infos = A.new_empty(A.shape[:-2], dtype=torch.int32)
  964. return L, infos
  965. @register_meta([aten.linalg_ldl_factor_ex.default, aten.linalg_ldl_factor_ex.out])
  966. @out_wrapper("LD", "pivots", "info")
  967. def linalg_ldl_factor_ex_meta(
  968. self: Tensor,
  969. *,
  970. hermitian: bool = False,
  971. check_errors: bool = False,
  972. ) -> tuple[Tensor, Tensor, Tensor]:
  973. squareCheckInputs(self, "torch.linalg.ldl_factor_ex")
  974. checkFloatingOrComplex(self, "torch.linalg.ldl_factor_ex")
  975. LD = torch.empty_strided(
  976. size=self.shape,
  977. stride=make_contiguous_strides_for(self.shape, row_major=False),
  978. dtype=self.dtype,
  979. device=self.device,
  980. )
  981. pivots = self.new_empty(self.shape[:-1], dtype=torch.int)
  982. info = self.new_empty(self.shape[:-2], dtype=torch.int)
  983. return LD, pivots, info
  984. @register_meta([aten.linalg_ldl_solve.default, aten.linalg_ldl_solve.out])
  985. @out_wrapper()
  986. def linalg_ldl_solve_meta(
  987. LD: Tensor,
  988. pivots: Tensor,
  989. B: Tensor,
  990. *,
  991. hermitian: bool = False,
  992. ) -> Tensor:
  993. squareCheckInputs(LD, "torch.linalg.ldl_solve")
  994. checkFloatingOrComplex(LD, "torch.linalg.ldl_solve")
  995. linearSolveCheckInputs(B, LD, "torch.linalg.ldl_solve")
  996. torch._check(
  997. B.ndim >= 2,
  998. lambda: (
  999. f"torch.linalg.ldl_solve: Expected B to have at least 2 dimensions, "
  1000. f"but it has {B.ndim} dimensions instead"
  1001. ),
  1002. )
  1003. expected_pivots_shape = LD.shape[:-1]
  1004. torch._check(
  1005. expected_pivots_shape == pivots.shape,
  1006. lambda: (
  1007. f"torch.linalg.ldl_solve: Expected LD.shape[:-1] and pivots.shape to be the same, "
  1008. f"but got pivots with shape {pivots.shape} instead"
  1009. ),
  1010. )
  1011. torch._check(
  1012. utils.is_integer_dtype(pivots.dtype),
  1013. lambda: f"torch.linalg.ldl_solve: Expected pivots to be integers. Got {pivots.dtype}",
  1014. )
  1015. torch._check(
  1016. LD.dtype == B.dtype,
  1017. lambda: f"torch.linalg.ldl_solve: LD dtype {LD.dtype} does not match b dtype {B.dtype}",
  1018. )
  1019. B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LD)
  1020. return torch.empty_strided(
  1021. size=B_broadcast_size,
  1022. stride=make_contiguous_strides_for(B_broadcast_size, row_major=False),
  1023. dtype=B.dtype,
  1024. device=B.device,
  1025. )
  1026. @register_meta([aten.linalg_lu.default, aten.linalg_lu.out])
  1027. @out_wrapper("P", "L", "U")
  1028. def linalg_lu_meta(A: Tensor, *, pivot: bool = True) -> tuple[Tensor, Tensor, Tensor]:
  1029. torch._check(
  1030. A.ndim >= 2,
  1031. lambda: f"linalg.lu: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
  1032. )
  1033. sizes = list(A.shape)
  1034. m = sizes[-2]
  1035. n = sizes[-1]
  1036. k = min(m, n)
  1037. sizes[-1] = m
  1038. if pivot:
  1039. P = A.new_empty(sizes)
  1040. else:
  1041. P = A.new_empty([0])
  1042. sizes[-1] = k
  1043. L = A.new_empty(sizes)
  1044. sizes[-2] = k
  1045. sizes[-1] = n
  1046. U = A.new_empty(sizes)
  1047. return P, L, U
  1048. @register_meta([aten.linalg_lu_factor_ex.default, aten.linalg_lu_factor_ex.out])
  1049. @out_wrapper("LU", "pivots", "info")
  1050. def linalg_lu_factor_ex_meta(
  1051. A: Tensor,
  1052. *,
  1053. pivot: bool = True,
  1054. check_errors: bool = False,
  1055. ) -> tuple[Tensor, Tensor, Tensor]:
  1056. torch._check(
  1057. A.ndim >= 2,
  1058. lambda: f"torch.lu_factor: Expected tensor with 2 or more dimensions. Got size: {A.shape} instead",
  1059. )
  1060. sizes = list(A.shape)
  1061. m = sizes[-2]
  1062. n = sizes[-1]
  1063. LU = torch.empty_strided(
  1064. size=sizes,
  1065. stride=make_contiguous_strides_for(sizes, row_major=False),
  1066. dtype=A.dtype,
  1067. device=A.device,
  1068. )
  1069. # Sets sizes to the size of pivots
  1070. sizes.pop()
  1071. sizes[-1] = min(m, n)
  1072. pivots = A.new_empty(sizes, dtype=torch.int)
  1073. # Sets sizes to the size of info
  1074. sizes.pop()
  1075. info = A.new_empty(sizes, dtype=torch.int)
  1076. return LU, pivots, info
  1077. @register_meta([aten.linalg_lu_solve.default, aten.linalg_lu_solve.out])
  1078. @out_wrapper()
  1079. def linalg_lu_solve_meta(
  1080. LU: Tensor,
  1081. pivots: Tensor,
  1082. B: Tensor,
  1083. *,
  1084. left: bool = True,
  1085. adjoint: bool = False,
  1086. ) -> Tensor:
  1087. # dtype
  1088. checkFloatingOrComplex(LU, "torch.linalg.lu_solve")
  1089. torch._check(
  1090. LU.dtype == B.dtype,
  1091. lambda: (
  1092. f"linalg.lu_solve: Expected LU and B to have the same dtype, "
  1093. f"but found LU of type {LU.dtype} and B of type {B.dtype} instead"
  1094. ),
  1095. )
  1096. torch._check(
  1097. pivots.dtype == torch.int,
  1098. lambda: "linalg.lu_solve: pivots should be a Tensor of scalar type torch.int32",
  1099. )
  1100. # matrix shapes
  1101. squareCheckInputs(LU, "torch.linalg.lu_solve")
  1102. checkInputsSolver(LU, B, left, "linalg.lu_solve")
  1103. torch._check(
  1104. LU.size(-1) == pivots.size(-1),
  1105. lambda: "linalg.lu_solve: Number of pivots per batch should be same as the dimension of the matrix",
  1106. )
  1107. # batches
  1108. torch._check(
  1109. LU.shape[:-1] == pivots.shape,
  1110. lambda: (
  1111. f"linalg.lu_solve: Expected LU.shape[:-1] and pivots.shape to be the same, "
  1112. f"but got pivots with shape {pivots.shape} instead"
  1113. ),
  1114. )
  1115. B_broadcast_size, _ = _linalg_broadcast_batch_dims(B, LU)
  1116. result = torch.empty_strided(
  1117. size=B_broadcast_size,
  1118. stride=make_contiguous_strides_for(B_broadcast_size, row_major=not left),
  1119. dtype=B.dtype,
  1120. device=B.device,
  1121. )
  1122. if result.numel() != 0 and not left:
  1123. if result.is_complex():
  1124. result = result.conj()
  1125. return result
  1126. @register_meta(aten.lu_unpack)
  1127. @out_wrapper("P", "L", "U")
  1128. def lu_unpack_meta(
  1129. LU: Tensor,
  1130. pivots: Tensor,
  1131. unpack_data: bool = True,
  1132. unpack_pivots: bool = True,
  1133. ) -> tuple[Tensor, Tensor, Tensor]:
  1134. torch._check(
  1135. LU.ndim >= 2,
  1136. lambda: f"torch.lu_unpack: Expected tensor with 2 or more dimensions. Got size: {LU.shape} instead",
  1137. )
  1138. if unpack_pivots:
  1139. torch._check(
  1140. pivots.dtype == torch.int32,
  1141. lambda: (
  1142. "torch.lu_unpack: LU_pivots is expected to be a contiguous tensor of torch.int32 dtype.\n"
  1143. "Note: this function is intended to be used with the output produced by torch.linalg.lu_factor"
  1144. ),
  1145. )
  1146. sizes = list(LU.shape)
  1147. m = sizes[-2]
  1148. n = sizes[-1]
  1149. k = min(m, n)
  1150. sizes[-1] = m
  1151. if unpack_pivots:
  1152. P = LU.new_empty(sizes)
  1153. else:
  1154. P = LU.new_empty([0])
  1155. if unpack_data:
  1156. sizes[-1] = k
  1157. L = LU.new_empty(sizes)
  1158. sizes[-2] = k
  1159. sizes[-1] = n
  1160. U = LU.new_empty(sizes)
  1161. else:
  1162. L = LU.new_empty([0])
  1163. U = LU.new_empty([0])
  1164. return P, L, U
  1165. # parse the "mode" param in linalg_qr: return a tuple of bools (compute_q, reduced)
  1166. def _parse_qr_mode(mode: str) -> tuple[bool, bool]:
  1167. if mode == "reduced":
  1168. compute_q = True
  1169. reduced = True
  1170. elif mode == "complete":
  1171. compute_q = True
  1172. reduced = False
  1173. elif mode == "r":
  1174. compute_q = False
  1175. reduced = True # this is actually irrelevant in this mode
  1176. else:
  1177. torch._check(
  1178. False,
  1179. lambda: (
  1180. f"qr received unrecognized mode '{mode}' "
  1181. f"but expected one of 'reduced' (default), 'r', or 'complete'"
  1182. ),
  1183. )
  1184. return compute_q, reduced # type: ignore[possibly-undefined]
  1185. @register_meta([aten.linalg_qr.default, aten.linalg_qr.out])
  1186. @out_wrapper("Q", "R")
  1187. def linalg_qr_meta(A: Tensor, mode: str = "reduced") -> tuple[Tensor, Tensor]:
  1188. checkIsMatrix(A, "linalg.qr")
  1189. checkFloatingOrComplex(A, "linalg.qr")
  1190. compute_q, reduced_mode = _parse_qr_mode(mode)
  1191. m = A.shape[-2]
  1192. n = A.shape[-1]
  1193. k = min(m, n)
  1194. if compute_q:
  1195. Q_shape = list(A.shape)
  1196. Q_shape[-1] = k if reduced_mode else m
  1197. Q = A.new_empty(Q_shape)
  1198. Q.as_strided_(Q_shape, make_contiguous_strides_for(Q_shape, row_major=False))
  1199. else:
  1200. Q = A.new_empty([0])
  1201. # For readability
  1202. R_shape = list(A.shape)
  1203. R_shape[-2] = k if reduced_mode or not compute_q else m
  1204. R = A.new_empty(R_shape)
  1205. R.as_strided_(R_shape, make_contiguous_strides_for(R_shape, row_major=False))
  1206. return Q, R
  1207. @register_meta([aten._linalg_slogdet.default, aten._linalg_slogdet.sign])
  1208. @out_wrapper("sign", "logabsdet", "LU", "pivots")
  1209. def _linalg_slogdet(A: Tensor) -> tuple[Tensor, Tensor, Tensor, Tensor]:
  1210. squareCheckInputs(A, "linalg.slogdet")
  1211. checkFloatingOrComplex(A, "linalg.slogdet", False)
  1212. shape = A.shape
  1213. sign = A.new_empty(shape[:-2])
  1214. logabsdet = A.new_empty(shape[:-2], dtype=toRealValueType(A.dtype))
  1215. LU = torch.empty_strided(
  1216. size=shape,
  1217. stride=make_contiguous_strides_for(shape, False),
  1218. dtype=A.dtype,
  1219. device=A.device,
  1220. )
  1221. pivots = A.new_empty(shape[:-1], dtype=torch.int32)
  1222. return sign, logabsdet, LU, pivots
  1223. # From aten/src/ATen/native/BatchLinearAlgebra.cpp
  1224. # NOTE: matching defaults in aten/src/ATen/native/native_functions.yaml
  1225. @register_meta(aten._linalg_svd.default)
  1226. def _linalg_svd_meta(
  1227. A: Tensor,
  1228. full_matrices: bool = False,
  1229. compute_uv: bool = True,
  1230. driver: str | None = None,
  1231. ):
  1232. checkIsMatrix(A, "linalg.svd")
  1233. checkFloatingOrComplex(A, "linalg.svd")
  1234. batch_dims = list(A.shape[:-2])
  1235. m = A.shape[-2]
  1236. n = A.shape[-1]
  1237. k = min(m, n)
  1238. if compute_uv:
  1239. U_shape = batch_dims + [m, m if full_matrices else k]
  1240. U = A.new_empty(U_shape)
  1241. U.as_strided_(U_shape, make_contiguous_strides_for(U_shape, row_major=False))
  1242. V_shape = batch_dims + [n if full_matrices else k, n]
  1243. V = A.new_empty(V_shape)
  1244. # NB: This checks for CUDA since there is no way to check for cuSolver.
  1245. # Also, this might not work correctly on CPU when fake_device is not
  1246. # available as device_hint just defaults to CUDA in that case. See
  1247. # _linalg_svd meta in core.
  1248. is_cuda = device_hint(A) == "cuda"
  1249. V.as_strided_(V_shape, make_contiguous_strides_for(V_shape, row_major=is_cuda))
  1250. else:
  1251. # doesn't matter
  1252. U = A.new_empty([0])
  1253. V = A.new_empty([0])
  1254. # S is always real, even when A is complex.
  1255. S = A.new_empty(batch_dims + [k], dtype=toRealValueType(A.dtype))
  1256. return U, S, V
  1257. def _linalg_broadcast_batch_dims(
  1258. arg1: Tensor,
  1259. arg2: Tensor,
  1260. ) -> tuple[list[int], list[int]]:
  1261. # broadcast the batch dimensions of arg1 and arg2.
  1262. arg1_batch_sizes = arg1.shape[:-2]
  1263. arg2_batch_sizes = arg2.shape[:-2]
  1264. expand_batch_portion = _broadcast_shapes(arg1_batch_sizes, arg2_batch_sizes)
  1265. arg1_expand_size = list(expand_batch_portion)
  1266. arg1_expand_size += [arg1.size(-2), arg1.size(-1)]
  1267. arg2_expand_size = list(expand_batch_portion)
  1268. arg2_expand_size += [arg2.size(-2), arg2.size(-1)]
  1269. return arg1_expand_size, arg2_expand_size
  1270. def _linalg_broadcast_batch_dims_name(
  1271. arg1: Tensor,
  1272. arg2: Tensor,
  1273. name: str | None,
  1274. ) -> tuple[Tensor, Tensor]:
  1275. # If there's no name we assume we don't want to check the errors
  1276. if name:
  1277. linearSolveCheckInputs(arg1, arg2, name)
  1278. arg1_expand_size, arg2_expand_size = _linalg_broadcast_batch_dims(arg1, arg2)
  1279. arg1_broadcasted = (
  1280. arg1 if arg1_expand_size == arg1.shape else arg1.expand(arg1_expand_size)
  1281. )
  1282. arg2_broadcasted = (
  1283. arg2 if arg2_expand_size == arg2.shape else arg2.expand(arg2_expand_size)
  1284. )
  1285. return arg1_broadcasted, arg2_broadcasted
  1286. def linalg_solve_is_vector_rhs(input: Tensor, other: Tensor) -> bool:
  1287. expected_batched_rhs_shape = input.shape[:-1]
  1288. vector_case = other.ndim == 1 or (
  1289. input.ndim - 1 == other.ndim and other.shape == expected_batched_rhs_shape
  1290. )
  1291. return vector_case
@register_meta(aten._linalg_solve_ex)
def _linalg_solve_ex(
    A: Tensor,
    B: Tensor,
    *,
    left: bool = True,
    check_errors: bool = False,
    result: Tensor | None = None,
    LU: Tensor | None = None,
    pivots: Tensor | None = None,
    info: Tensor | None = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
    """Meta kernel for ``_linalg_solve_ex``: returns ``(result, LU, pivots, info)``.

    ``B`` is treated as a (batch of) vector(s) when
    ``linalg_solve_is_vector_rhs(A, B)`` holds. It is then unsqueezed for the
    shape checks, and the trailing dim is dropped again from ``result``'s shape.
    Vector right-hand sides are rejected when ``left`` is False.

    Output metadata:
    - ``result`` has ``B``'s shape broadcast against ``A``'s batch dims and
      ``B``'s dtype and device. Its strides come from
      ``make_contiguous_strides_for(..., not left)``, presumably row-major iff
      ``left`` is False.
    - ``LU`` has ``A``'s shape with column-major strides.
    - ``pivots`` is int32 over ``A.shape[:-1]``.
    - ``info`` is int32 over ``A.shape[:-2]``.

    Out handling: this kernel is not wrapped by ``out_wrapper``. Only when all
    four out tensors (``result``, ``LU``, ``pivots``, ``info``) are given are
    they resized, restrided to the computed metadata and filled via
    ``_safe_copy_out``. The freshly created tensors are returned in every case.
    """
    checkFloatingOrComplex(A, "linalg.solve")
    torch._check(
        A.dtype == B.dtype,
        lambda: (
            f"linalg.solve: Expected A and B to have the same dtype, but found A of type "
            f"{A.dtype} and B of type {B.dtype} instead"
        ),
    )
    vector_case = linalg_solve_is_vector_rhs(A, B)
    # Work with a matrix-shaped RHS so the shape checks below apply uniformly.
    B_ = B.unsqueeze(-1) if vector_case else B
    checkInputsSolver(A, B_, left, "linalg.solve")
    B_broad_shape, _ = _linalg_broadcast_batch_dims(B_, A)
    torch._check(
        left or not vector_case,
        lambda: (
            "linalg.solve: Vector broadcasting of the left hand side is not supported for left=False. "
            "In this case linalg.solve is equivalent to B / A.squeeze(-1)"
        ),
    )
    # Undo the unsqueeze above for vector right-hand sides.
    result_shape = B_broad_shape[:-1] if vector_case else B_broad_shape
    result_ = torch.empty_strided(
        size=result_shape,
        stride=make_contiguous_strides_for(result_shape, not left),
        dtype=B.dtype,
        device=B.device,
    )
    shape = A.shape
    LU_ = torch.empty_strided(
        size=shape,
        stride=make_contiguous_strides_for(shape, False),
        dtype=A.dtype,
        device=A.device,
    )
    pivots_ = A.new_empty(shape[:-1], dtype=torch.int32)
    info_ = A.new_empty(shape[:-2], dtype=torch.int32)
    out = (result, LU, pivots, info)
    res = (result_, LU_, pivots_, info_)
    # Out-variant: populate the caller's tensors only when all four are provided.
    if all(x is not None for x in out):
        for r, o in zip(res, out):
            # resize and copy operations are done in-place
            _maybe_resize_out(o, r.shape)  # type: ignore[arg-type]
            # strides are not copied in out_wrapper
            o.as_strided_(r.shape, r.stride())  # type: ignore[union-attr]
            _safe_copy_out(copy_from=r, copy_to=o, exact_dtype=False)  # type: ignore[arg-type]
    return res
@register_meta([aten.linalg_solve_triangular.default, aten.linalg_solve_triangular.out])
def linalg_solve_triangular_meta(
    A: Tensor,
    B: Tensor,
    *,
    upper: bool,
    left: bool = True,
    unitriangular: bool = False,
    out: Tensor | None = None,
) -> Tensor:
    """Meta kernel for ``linalg.solve_triangular``, functional and out variants.

    Validates ``A``/``B`` with ``checkInputsSolver``. It then broadcasts their
    batch dims; ``linearSolveCheckInputs`` is skipped because ``name`` is None.
    The output is ``out`` if given, else a fresh empty tensor from ``A``, and it
    is resized in place to ``B_.shape``. ``upper`` and ``unitriangular`` are not
    used.

    How the output is resized:
    - If ``A_.transpose(-2, -1)`` is contiguous and ``A_`` has its conj bit set,
      ``_maybe_resize_out`` handles the resize directly.
    - Otherwise, whenever ``_resize_output_check(out, B_.shape)`` is truthy
      (presumably "a resize is needed"), ``out`` is resized to the transposed
      shape and then transposed in place. That leaves it with Fortran-contiguous
      (column-major) strides over ``B_.shape``.
    """
    if out is None:
        # Functional call: start from an empty tensor and resize it below.
        out = A.new_empty([0])
    assert isinstance(out, TensorLike)
    checkInputsSolver(A, B, left, "linalg.solve_triangular")
    # name=None: only broadcast the batch dims, without re-running the solve checks.
    B_, A_ = _linalg_broadcast_batch_dims_name(B, A, None)
    avoid_copy_A = A_.transpose(-2, -1).is_contiguous() and A_.is_conj()
    if avoid_copy_A:
        out = _maybe_resize_out(out, B_.shape)
    else:
        # reimplementation of resize_output with result F-contig
        if _resize_output_check(out, B_.shape):
            out.resize_(B_.transpose(-2, -1).shape)
            out.transpose_(-2, -1)
    return out  # type: ignore[return-value]
  1373. @register_meta(aten.triangular_solve)
  1374. @out_wrapper("X", "M", exact_dtype=True)
  1375. def triangular_solve_meta(
  1376. self: Tensor,
  1377. A: Tensor,
  1378. upper: bool = True,
  1379. transpose: bool = False,
  1380. unitriangular: bool = False,
  1381. ) -> tuple[Tensor, Tensor]:
  1382. torch._check(
  1383. self.ndim >= 2,
  1384. lambda: (
  1385. f"torch.triangular_solve: Expected b to have at least 2 dimensions, "
  1386. f"but it has {self.ndim} dimensions instead"
  1387. ),
  1388. )
  1389. torch._check(
  1390. A.ndim >= 2,
  1391. lambda: (
  1392. f"torch.triangular_solve: Expected A to have at least 2 dimensions, "
  1393. f"but it has {A.ndim} dimensions instead"
  1394. ),
  1395. )
  1396. linearSolveCheckInputs(self, A, "triangular_solve")
  1397. if A.layout == torch.strided:
  1398. self_broadcast_size, A_broadcast_size = _linalg_broadcast_batch_dims(self, A)
  1399. solution = torch.empty_strided(
  1400. size=self_broadcast_size,
  1401. stride=make_contiguous_strides_for(self_broadcast_size, row_major=False),
  1402. dtype=self.dtype,
  1403. device=self.device,
  1404. )
  1405. cloned_coefficient = torch.empty_strided(
  1406. size=A_broadcast_size,
  1407. stride=make_contiguous_strides_for(A_broadcast_size, row_major=False),
  1408. dtype=A.dtype,
  1409. device=A.device,
  1410. )
  1411. elif A.layout == torch.sparse_csr or A.layout == torch.sparse_bsr:
  1412. solution = torch.empty_like(self)
  1413. cloned_coefficient = self.new_empty([0])
  1414. else:
  1415. torch._check(False, lambda: "triangular_solve: Got an unexpected layout.")
  1416. return solution, cloned_coefficient # type: ignore[possibly-undefined]
  1417. # From aten/src/ATen/native/LinearAlgebra.cpp
  1418. @register_meta(aten._linalg_det.default)
  1419. def _linalg_det_meta(A):
  1420. squareCheckInputs(A, "linalg.det")
  1421. checkFloatingOrComplex(A, "linalg.det")
  1422. det = A.new_empty(A.shape[:-2])
  1423. LU = A.new_empty(A.shape)
  1424. LU.as_strided_(A.shape, make_contiguous_strides_for(A.shape, row_major=False))
  1425. pivots = A.new_empty(A.shape[:-1], dtype=torch.int32)
  1426. return det, LU, pivots
  1427. @register_meta(aten.ormqr)
  1428. @out_wrapper()
  1429. def ormqr(
  1430. input: Tensor,
  1431. tau: Tensor,
  1432. other: Tensor,
  1433. left: bool = True,
  1434. transpose: bool = False,
  1435. ) -> Tensor:
  1436. torch._check(
  1437. input.ndim >= 2, lambda: "torch.ormqr: input must have at least 2 dimensions."
  1438. )
  1439. torch._check(
  1440. other.ndim >= 2, lambda: "torch.ormqr: other must have at least 2 dimensions."
  1441. )
  1442. left_size_condition = -2 if left else -1
  1443. torch._check(
  1444. other.shape[left_size_condition] >= tau.shape[-1],
  1445. lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be greater than or equal to tau.shape[-1]",
  1446. )
  1447. torch._check(
  1448. other.shape[left_size_condition] == input.shape[-2],
  1449. lambda: f"torch.ormqr: other.shape[{left_size_condition}] must be equal to input.shape[-2]",
  1450. )
  1451. torch._check(
  1452. tau.shape[-1] <= input.shape[-1],
  1453. lambda: "torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]",
  1454. )
  1455. torch._check(
  1456. input.ndim - tau.ndim == 1,
  1457. lambda: (
  1458. f"torch.ormqr: Expected tau to have one dimension less than input, "
  1459. f"but got tau.ndim equal to {tau.ndim} and input.ndim is equal to {input.ndim}"
  1460. ),
  1461. )
  1462. torch._check(
  1463. input.ndim == other.ndim,
  1464. lambda: (
  1465. f"torch.ormqr: Expected other to have the same number of dimensions as input, "
  1466. f"but got other.ndim equal to {other.ndim} and input.ndim is equal to {input.ndim}"
  1467. ),
  1468. )
  1469. if input.ndim > 2:
  1470. expected_batch_shape = input.shape[:-2]
  1471. actual_batch_tau_shape = tau.shape[:-1]
  1472. torch._check(
  1473. actual_batch_tau_shape == expected_batch_shape,
  1474. lambda: (
  1475. f"torch.ormqr: Expected batch dimensions of tau to be "
  1476. f"equal to input.shape[:-2], but got {actual_batch_tau_shape}"
  1477. ),
  1478. )
  1479. actual_batch_other_shape = other.shape[:-2]
  1480. torch._check(
  1481. actual_batch_other_shape == expected_batch_shape,
  1482. lambda: (
  1483. f"torch.ormqr: Expected batch dimensions of other to be "
  1484. f"equal to input.shape[:-2], but got {actual_batch_other_shape}"
  1485. ),
  1486. )
  1487. torch._check(
  1488. tau.dtype == input.dtype,
  1489. lambda: (
  1490. f"torch.ormqr: Expected input and tau to have the same dtype, "
  1491. f"but input has dtype {input.dtype} and tau has dtype {tau.dtype}"
  1492. ),
  1493. )
  1494. torch._check(
  1495. other.dtype == input.dtype,
  1496. lambda: (
  1497. f"torch.ormqr: Expected input and other to have the same dtype, "
  1498. f"but input has dtype {input.dtype} and other has dtype {other.dtype}"
  1499. ),
  1500. )
  1501. checkSameDevice("torch.ormqr", tau, input, "tau")
  1502. checkSameDevice("torch.ormqr", other, input, "other")
  1503. return torch.empty_strided(
  1504. size=other.shape,
  1505. stride=make_contiguous_strides_for(other.shape, row_major=False),
  1506. dtype=other.dtype,
  1507. device=other.device,
  1508. )
def _padding_check_valid_input(input, padding, *, dim):
    """Validate ``padding`` length and ``input`` shape for a ``dim``-D padding op.

    Checks, each failing via ``torch._check`` (RuntimeError):
      * ``len(padding) == 2 * dim``. The padding kernels here unpack one
        (before, after) pair per padded dim.
      * ``input`` is valid in one of two ways. With ``dim + 2`` dims (batch
        mode), every dim except dim 0 must be non-empty, so a zero-size batch is
        allowed. With any other rank, every dim must be non-empty.

    NOTE(review): the non-batch branch does not itself require ``dim + 1``
    dims. Any rank other than ``dim + 2`` with no empty dims passes here, even
    though the error message mentions only ``dim + 1``-D / ``dim + 2``-D tensors.
    """
    torch._check(
        len(padding) == 2 * dim,
        lambda: f"padding size is expected to be {2 * dim}, but got: {len(padding)}",
    )
    input_dim = input.ndim
    is_batch_mode = input_dim == (dim + 2)
    valid_batch_mode = is_batch_mode
    valid_non_batch_mode = not is_batch_mode
    # The flags are accumulated with `and`, which short-circuits after the first
    # empty dim. This ordering also governs how symbolic sizes are evaluated.
    if is_batch_mode:
        # allow batch size of 0-dim.
        for d in range(1, input_dim):
            valid_batch_mode = valid_batch_mode and input.size(d) != 0
    else:
        for d in range(input_dim):
            valid_non_batch_mode = valid_non_batch_mode and input.size(d) != 0
    # allow empty batch size but not other dimensions.
    torch._check(
        valid_batch_mode or valid_non_batch_mode,
        lambda: (
            f"Expected {dim + 1}D or {dim + 2}D (batch mode) tensor with possibly 0 batch size "
            f"and other non-zero dimensions for input, but got: {input.shape}"
        ),
    )
  1533. def _pad1d_common(input, padding, *, is_reflection):
  1534. dim_plane = 0
  1535. dim_w = 1
  1536. nbatch = 1
  1537. if input.ndim == 3:
  1538. nbatch = input.size(0)
  1539. dim_w += 1
  1540. dim_plane += 1
  1541. _padding_check_valid_input(input, padding, dim=1)
  1542. pad_l, pad_r = padding
  1543. nplane = input.size(dim_plane)
  1544. input_w = input.size(dim_w)
  1545. output_w = input_w + pad_l + pad_r
  1546. if is_reflection:
  1547. torch._check(
  1548. pad_l < input_w and pad_r < input_w,
  1549. lambda: (
  1550. f"Argument #4: Padding size should be less than the corresponding input dimension, "
  1551. f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
  1552. ),
  1553. )
  1554. torch._check(
  1555. output_w >= 1,
  1556. lambda: f"input (W: {input_w}) is too small. Calculated output W: {output_w}",
  1557. )
  1558. if input.ndim == 2:
  1559. return input.new_empty((nplane, output_w))
  1560. else:
  1561. return input.new_empty((nbatch, nplane, output_w))
  1562. @register_meta(aten.reflection_pad1d)
  1563. @out_wrapper()
  1564. def meta_reflection_pad1d(input, padding):
  1565. return _pad1d_common(input, padding, is_reflection=True)
  1566. @register_meta(aten.replication_pad1d)
  1567. @out_wrapper()
  1568. def meta_replication_pad1d(input, padding):
  1569. torch._check(
  1570. input.dtype != torch.bool,
  1571. lambda: f""""replication_pad1d" not implemented for '{input.dtype.__str__()}'""",
  1572. )
  1573. return _pad1d_common(input, padding, is_reflection=False)
  1574. def _pad1d_backward_common(grad_output, input, padding, *, is_reflection):
  1575. dim_w = 1
  1576. if not is_reflection:
  1577. torch._check(len(padding) == 2, lambda: "padding size is expected to be 2")
  1578. if input.ndim == 3:
  1579. dim_w += 1
  1580. pad_l, pad_r = padding
  1581. input_w = input.size(dim_w)
  1582. output_w = input_w + pad_l + pad_r
  1583. if is_reflection:
  1584. torch._check(
  1585. pad_l < input_w and pad_r < input_w,
  1586. lambda: (
  1587. f"Argument #4: Padding size should be less than the corresponding input dimension, "
  1588. f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
  1589. ),
  1590. )
  1591. torch._check(
  1592. output_w == grad_output.size(dim_w),
  1593. lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
  1594. )
  1595. return input.new_empty(input.shape)
  1596. @register_meta(aten.reflection_pad1d_backward)
  1597. @out_wrapper("grad_input")
  1598. def meta_reflection_pad1d_backward(grad_output, input, padding):
  1599. return _pad1d_backward_common(grad_output, input, padding, is_reflection=True)
  1600. @register_meta(aten.replication_pad1d_backward)
  1601. @out_wrapper("grad_input")
  1602. def meta_replication_pad1d_backward(grad_output, input, padding):
  1603. return _pad1d_backward_common(grad_output, input, padding, is_reflection=False)
  1604. def _pad2d_common(input, padding, *, is_reflection):
  1605. dim_w = 2
  1606. dim_h = 1
  1607. dim_slices = 0
  1608. nbatch = 1
  1609. _padding_check_valid_input(input, padding, dim=2)
  1610. ndim = input.ndim
  1611. if ndim == 4:
  1612. nbatch = input.size(0)
  1613. dim_w += 1
  1614. dim_h += 1
  1615. dim_slices += 1
  1616. pad_l, pad_r, pad_t, pad_b = padding
  1617. nplane = input.size(dim_slices)
  1618. input_h = input.size(dim_h)
  1619. input_w = input.size(dim_w)
  1620. output_h = input_h + pad_t + pad_b
  1621. output_w = input_w + pad_l + pad_r
  1622. if is_reflection:
  1623. torch._check(
  1624. pad_l < input_w and pad_r < input_w,
  1625. lambda: (
  1626. f"Argument #4: Padding size should be less than the corresponding input dimension, "
  1627. f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
  1628. ),
  1629. )
  1630. torch._check(
  1631. pad_t < input_h and pad_b < input_h,
  1632. lambda: (
  1633. f"Argument #6: Padding size should be less than the corresponding input dimension, "
  1634. f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
  1635. ),
  1636. )
  1637. torch._check(
  1638. output_w >= 1 or output_h >= 1,
  1639. lambda: (
  1640. f"input (H: {input_h} W: {input_w}) is too small. "
  1641. f"Calculated output H: {output_h} W: {output_w}"
  1642. ),
  1643. )
  1644. if input.ndim == 3:
  1645. return input.new_empty((nplane, output_h, output_w))
  1646. else:
  1647. return input.new_empty((nbatch, nplane, output_h, output_w))
  1648. @register_meta(aten.reflection_pad2d)
  1649. @out_wrapper()
  1650. def meta_reflection_pad2d(input, padding):
  1651. return _pad2d_common(input, padding, is_reflection=True)
  1652. @register_meta(aten.replication_pad2d)
  1653. @out_wrapper()
  1654. def meta_replication_pad2d(input, padding):
  1655. torch._check(
  1656. input.dtype != torch.bool,
  1657. lambda: f""""replication_pad2d" not implemented for '{input.dtype.__str__()}'""",
  1658. )
  1659. return _pad2d_common(input, padding, is_reflection=False)
  1660. @register_meta(
  1661. aten._weight_norm_interface_backward.default,
  1662. )
  1663. def meta_weight_norm_backward(grad_w, saved_v, saved_g, saved_norms, dim):
  1664. grad_v = torch.empty_like(saved_v)
  1665. grad_g = torch.empty_like(saved_g)
  1666. return grad_v, grad_g
  1667. @register_meta(
  1668. [
  1669. aten.reflection_pad2d_backward.default,
  1670. aten.reflection_pad2d_backward.grad_input,
  1671. aten.replication_pad2d_backward.default,
  1672. aten.replication_pad2d_backward.grad_input,
  1673. ]
  1674. )
  1675. @out_wrapper("grad_input")
  1676. def meta_pad2d_backward(grad_output, self, padding):
  1677. dim_w = 2
  1678. dim_h = 1
  1679. dim_plane = 0
  1680. self_shape = self.shape
  1681. if self.dim() == 4:
  1682. dim_w += 1
  1683. dim_h += 1
  1684. dim_plane += 1
  1685. pad_l, pad_r, pad_t, pad_b = padding
  1686. input_h = self_shape[dim_h]
  1687. input_w = self_shape[dim_w]
  1688. output_h = input_h + pad_t + pad_b
  1689. output_w = input_w + pad_l + pad_r
  1690. torch._check(
  1691. output_w == grad_output.size(dim_w),
  1692. lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
  1693. )
  1694. torch._check(
  1695. output_h == grad_output.size(dim_h),
  1696. lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
  1697. )
  1698. return self.new_empty(self.shape)
  1699. def _pad3d_common(input, padding, *, is_reflection):
  1700. dim_w = 3
  1701. dim_h = 2
  1702. dim_d = 1
  1703. dim_plane = 0
  1704. _padding_check_valid_input(input, padding, dim=3)
  1705. batch_mode = input.ndim == 5
  1706. if batch_mode:
  1707. nbatch = input.size(0)
  1708. dim_w += 1
  1709. dim_h += 1
  1710. dim_d += 1
  1711. dim_plane += 1
  1712. pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
  1713. nplane = input.size(dim_plane)
  1714. input_d = input.size(dim_d)
  1715. input_h = input.size(dim_h)
  1716. input_w = input.size(dim_w)
  1717. output_d = input_d + pad_f + pad_bk
  1718. output_h = input_h + pad_t + pad_b
  1719. output_w = input_w + pad_l + pad_r
  1720. if is_reflection:
  1721. torch._check(
  1722. pad_l < input_w and pad_r < input_w,
  1723. lambda: (
  1724. f"Argument #4: Padding size should be less than the corresponding input dimension, "
  1725. f"but got: padding ({pad_l}, {pad_r}) at dimension {dim_w} of input {input.shape}"
  1726. ),
  1727. )
  1728. torch._check(
  1729. pad_t < input_h and pad_b < input_h,
  1730. lambda: (
  1731. f"Argument #6: Padding size should be less than the corresponding input dimension, "
  1732. f"but got: padding ({pad_t}, {pad_b}) at dimension {dim_h} of input {input.shape}"
  1733. ),
  1734. )
  1735. torch._check(
  1736. pad_f < input_d and pad_bk < input_d,
  1737. lambda: (
  1738. f"Argument #8: Padding size should be less than the corresponding input dimension, "
  1739. f"but got: padding ({pad_f}, {pad_bk}) at dimension {dim_d} of input {input.shape}"
  1740. ),
  1741. )
  1742. torch._check(
  1743. output_w >= 1 or output_h >= 1 or output_d >= 1,
  1744. lambda: (
  1745. f"input (D: {input_d} H: {input_h} W: {input_w}) is too small. "
  1746. f"Calculated output D: {output_d} H: {output_h} W: {output_w}"
  1747. ),
  1748. )
  1749. if batch_mode:
  1750. return input.new_empty((nbatch, nplane, output_d, output_h, output_w)) # type: ignore[possibly-undefined]
  1751. else:
  1752. return input.new_empty((nplane, output_d, output_h, output_w))
  1753. @register_meta(aten.reflection_pad3d)
  1754. @out_wrapper()
  1755. def meta_reflection_pad3d(input, padding):
  1756. return _pad3d_common(input, padding, is_reflection=True)
  1757. @register_meta(aten.replication_pad3d)
  1758. @out_wrapper()
  1759. def meta_replication_pad3d(input, padding):
  1760. torch._check(
  1761. input.dtype != torch.bool,
  1762. lambda: f""""replication_pad3d" not implemented for '{input.dtype.__str__()}'""",
  1763. )
  1764. return _pad3d_common(input, padding, is_reflection=False)
  1765. @register_meta(
  1766. [
  1767. aten.reflection_pad3d_backward.default,
  1768. aten.reflection_pad3d_backward.grad_input,
  1769. aten.replication_pad3d_backward.default,
  1770. aten.replication_pad3d_backward.grad_input,
  1771. ]
  1772. )
  1773. @out_wrapper("grad_input")
  1774. def meta_pad3d_backward(grad_output, input, padding):
  1775. torch._check(len(padding) == 6, lambda: "padding size is expected to be 6")
  1776. assert input.ndim > 3
  1777. assert grad_output.ndim == input.ndim
  1778. dim_w = 3
  1779. dim_h = 2
  1780. dim_d = 1
  1781. if input.ndim == 5:
  1782. dim_w += 1
  1783. dim_h += 1
  1784. dim_d += 1
  1785. pad_l, pad_r, pad_t, pad_b, pad_f, pad_bk = padding
  1786. input_d = input.size(dim_d)
  1787. input_h = input.size(dim_h)
  1788. input_w = input.size(dim_w)
  1789. output_d = input_d + pad_f + pad_bk
  1790. output_h = input_h + pad_t + pad_b
  1791. output_w = input_w + pad_l + pad_r
  1792. torch._check(
  1793. output_w == grad_output.size(dim_w),
  1794. lambda: f"grad_output width unexpected. Expected: {output_w}, Got: {grad_output.size(dim_w)}",
  1795. )
  1796. torch._check(
  1797. output_h == grad_output.size(dim_h),
  1798. lambda: f"grad_output height unexpected. Expected: {output_h}, Got: {grad_output.size(dim_h)}",
  1799. )
  1800. torch._check(
  1801. output_d == grad_output.size(dim_d),
  1802. lambda: f"grad_output depth unexpected. Expected: {output_d}, Got: {grad_output.size(dim_d)}",
  1803. )
  1804. return input.new_empty(input.shape)
  1805. @register_meta(aten._pdist_forward)
  1806. @out_wrapper()
  1807. def meta__pdist_forward(self: Tensor, p: float = 2) -> Tensor:
  1808. torch._check(
  1809. self.is_contiguous(), lambda: "_pdist_forward requires contiguous input"
  1810. )
  1811. n = self.size(0)
  1812. if n <= 1:
  1813. return self.new_empty([0]).to(memory_format=torch.legacy_contiguous_format) # type: ignore[call-overload]
  1814. else:
  1815. return self.new_empty((n * (n - 1) // 2,)).to(
  1816. memory_format=torch.legacy_contiguous_format
  1817. ) # type: ignore[call-overload]
  1818. @register_meta(aten._pdist_backward)
  1819. @out_wrapper()
  1820. def meta__pdist_backward(grad: Tensor, self: Tensor, p: float, pdist: Tensor) -> Tensor:
  1821. torch._check(
  1822. self.is_contiguous(), lambda: "_pdist_backward requires self to be contiguous"
  1823. )
  1824. torch._check(
  1825. pdist.is_contiguous(), lambda: "_pdist_backward requires pdist to be contiguous"
  1826. )
  1827. return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
  1828. @register_meta([aten.baddbmm.default, aten.baddbmm.out])
  1829. @out_wrapper(exact_dtype=True)
  1830. def meta_baddbmm(self, batch1, batch2, *, beta=1, alpha=1):
  1831. from torch.fx.experimental.symbolic_shapes import guard_or_true, sym_eq
  1832. dim1 = batch1.size(0)
  1833. dim2 = batch1.size(1)
  1834. dim3 = batch2.size(2)
  1835. if guard_or_true(torch.sym_not(sym_eq(self.shape, (dim1, dim2, dim3)))):
  1836. self = self.expand((dim1, dim2, dim3))
  1837. torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
  1838. torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
  1839. if not exp_config.skip_dtype_check_in_meta_registrations:
  1840. torch._check(
  1841. self.dtype == batch1.dtype == batch2.dtype,
  1842. lambda: f"Input dtypes must be the same, got: input: {self.dtype}, batch1: {batch1.dtype}, batch2: {batch2.dtype}",
  1843. )
  1844. batch1_sizes = batch1.shape
  1845. batch2_sizes = batch2.shape
  1846. bs = batch1_sizes[0]
  1847. contraction_size = batch1_sizes[2]
  1848. torch._check(
  1849. batch2_sizes[0] == bs and batch2_sizes[1] == contraction_size,
  1850. lambda: (
  1851. f"Expected size for first two dimensions of batch2 tensor to be: "
  1852. f"[{bs}, {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}]."
  1853. ),
  1854. )
  1855. return self.new_empty(self.size())
  1856. @register_meta([aten.bernoulli.default, aten.bernoulli.out])
  1857. @out_wrapper()
  1858. def meta_bernoulli(self, *, generator=None):
  1859. # https://github.com/pytorch/pytorch/issues/88612
  1860. return torch.empty_like(self, memory_format=torch.contiguous_format)
  1861. @register_meta(aten.bernoulli_.float)
  1862. def meta_bernoulli_(self, p=0.5, generator=None):
  1863. return self
  1864. @register_meta(aten.bernoulli.p)
  1865. def meta_bernoulli_p(self, p=0.5, generator=None):
  1866. # https://github.com/pytorch/pytorch/issues/88612
  1867. return torch.empty_like(self, memory_format=torch.contiguous_format)
  1868. @register_meta([aten.poisson.default, aten.poisson.out])
  1869. @out_wrapper()
  1870. def meta_poisson(self, generator=None):
  1871. return torch.empty_like(self)
  1872. @register_meta(aten._fused_moving_avg_obs_fq_helper.default)
  1873. def meta__fused_moving_avg_obs_fq_helper(
  1874. self,
  1875. observer_on,
  1876. fake_quant_on,
  1877. running_min,
  1878. running_max,
  1879. scale,
  1880. zero_point,
  1881. averaging_const,
  1882. quant_min,
  1883. quant_max,
  1884. ch_axis,
  1885. per_row_fake_quant=False,
  1886. symmetric_quant=False,
  1887. ):
  1888. torch._check(
  1889. ch_axis < self.dim(),
  1890. lambda: "Error in fused_moving_avg_obs_fake_quant_cpu: ch_axis must be < self.dim()",
  1891. )
  1892. mask = torch.empty_like(self, dtype=torch.bool)
  1893. return (torch.empty_like(self), mask)
  1894. @register_meta(aten.mm)
  1895. @out_wrapper(exact_dtype=True)
  1896. def meta_mm(a, b, out_dtype: torch.dtype | None = None):
  1897. torch._check(a.dim() == 2, lambda: "a must be 2D")
  1898. torch._check(b.dim() == 2, lambda: "b must be 2D")
  1899. N, M1 = a.shape
  1900. M2, P = b.shape
  1901. torch._check(
  1902. M1 == M2,
  1903. lambda: f"a and b must have same reduction dim, but got [{N}, {M1}] X [{M2}, {P}].",
  1904. )
  1905. if out_dtype is not None:
  1906. torch._check(
  1907. out_dtype == a.dtype
  1908. or (
  1909. out_dtype == torch.float32
  1910. and a.dtype in (torch.float16, torch.bfloat16)
  1911. ),
  1912. lambda: "out_dtype must be the same as input dtype or fp32 for fp16/bf16 inputs",
  1913. )
  1914. result_dtype = a.dtype if out_dtype is None else out_dtype
  1915. return a.new_empty((N, P), dtype=result_dtype)
  1916. def _compute_reduction_shape(self, dims, keepdim):
  1917. if keepdim:
  1918. return tuple(self.shape[i] if i not in dims else 1 for i in range(self.ndim))
  1919. return utils.compute_reduction_output_shape(self.shape, dims)
  1920. # FakeTensors (meta tensors with a device) will report device as meta
  1921. # when running meta kernels. Here, access the "fake device" of FakeTensor if it
  1922. # exists so meta kernels which have diverge per device will be more
  1923. # accurate when run with FakeTensors
  1924. def device_hint(tensor) -> "str":
  1925. if isinstance(tensor, torch._subclasses.FakeTensor):
  1926. return tensor.fake_device.type
  1927. elif (
  1928. hasattr(tensor, "device")
  1929. and hasattr(tensor.device, "type")
  1930. and tensor.device.type != "meta"
  1931. ):
  1932. return tensor.device.type
  1933. else:
  1934. return "cuda" # default to cuda
  1935. def calc_conv_nd_return_shape(
  1936. input_tensor: torch.Tensor,
  1937. weight: torch.Tensor,
  1938. stride: list[int] | int,
  1939. padding: list[int] | int,
  1940. dilation: list[int] | int,
  1941. is_transposed: bool,
  1942. groups: int,
  1943. output_padding: list[int] | int | None = None,
  1944. ):
  1945. def _formula(ln: int, p: int, d: int, k: int, s: int) -> int:
  1946. """
  1947. Formula to apply to calculate the length of some dimension of the output
  1948. See: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
  1949. Args:
  1950. ln: length of the dimension
  1951. p: padding in that dim
  1952. d: dilation in that dim
  1953. k: kernel size in that dim
  1954. s: stride in that dim
  1955. Returns:
  1956. The output length
  1957. """
  1958. return (ln + 2 * p - d * (k - 1) - 1) // s + 1
  1959. def _formula_transposed(ln: int, p: int, d: int, k: int, s: int, op: int) -> int:
  1960. """
  1961. Formula to apply to calculate the length of some dimension of the output
  1962. if transposed convolution is used.
  1963. See: https://pytorch.org/docs/stable/generated/torch.nn.ConvTranspose2d.html
  1964. Args:
  1965. ln: length of the dimension
  1966. p: padding in that dim
  1967. d: dilation in that dim
  1968. k: kernel size in that dim
  1969. s: stride in that dim
  1970. op: output padding in that dim
  1971. Returns:
  1972. The output length
  1973. """
  1974. return (ln - 1) * s - 2 * p + d * (k - 1) + op + 1
  1975. kernel_size = weight.shape[2:]
  1976. dims = input_tensor.shape[2:]
  1977. if is_transposed:
  1978. out_channels = groups * weight.shape[1]
  1979. else:
  1980. out_channels = weight.shape[0]
  1981. if weight.shape[1] * groups != input_tensor.shape[1]:
  1982. raise RuntimeError("Invalid channel dimensions")
  1983. ret_shape = [input_tensor.shape[0], out_channels]
  1984. if isinstance(stride, IntLike):
  1985. # pyrefly: ignore [bad-assignment]
  1986. stride = [stride] * len(dims)
  1987. elif len(stride) == 1:
  1988. stride = [stride[0]] * len(dims)
  1989. if isinstance(padding, IntLike):
  1990. # pyrefly: ignore [bad-assignment]
  1991. padding = [padding] * len(dims)
  1992. elif len(padding) == 1:
  1993. padding = [padding[0]] * len(dims)
  1994. if isinstance(dilation, IntLike):
  1995. # pyrefly: ignore [bad-assignment]
  1996. dilation = [dilation] * len(dims)
  1997. elif len(dilation) == 1:
  1998. dilation = [dilation[0]] * len(dims)
  1999. output_padding_list: list[int] | None = None
  2000. if output_padding:
  2001. if isinstance(output_padding, IntLike):
  2002. # pyrefly: ignore [bad-assignment]
  2003. output_padding_list = [output_padding] * len(dims)
  2004. elif len(output_padding) == 1:
  2005. output_padding_list = [output_padding[0]] * len(dims)
  2006. else:
  2007. output_padding_list = output_padding
  2008. for i in range(len(dims)):
  2009. # If output_padding is present, we are dealing with a transposed convolution
  2010. if output_padding_list:
  2011. ret_shape.append(
  2012. _formula_transposed(
  2013. dims[i],
  2014. # pyrefly: ignore [index-error]
  2015. padding[i],
  2016. # pyrefly: ignore [index-error]
  2017. dilation[i],
  2018. kernel_size[i],
  2019. # pyrefly: ignore [index-error]
  2020. stride[i],
  2021. output_padding_list[i],
  2022. )
  2023. )
  2024. else:
  2025. ret_shape.append(
  2026. # pyrefly: ignore [index-error]
  2027. _formula(dims[i], padding[i], dilation[i], kernel_size[i], stride[i])
  2028. )
  2029. from torch.fx.experimental.symbolic_shapes import sym_or
  2030. torch._check(
  2031. sym_or(*[x > 0 for x in ret_shape[2:]]),
  2032. lambda: f"Given input size per channel: {list(dims)}. "
  2033. f"Calculated output size per channel: {ret_shape[2:]}. "
  2034. f"Output size is too small",
  2035. )
  2036. return ret_shape
  2037. def is_channels_last(ten):
  2038. return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
  2039. @register_meta(aten.miopen_batch_norm.default)
  2040. def meta_miopen_batch_norm(
  2041. input_tensor: torch.Tensor,
  2042. weight: torch.Tensor,
  2043. bias: torch.Tensor | None,
  2044. running_mean: torch.Tensor | None,
  2045. running_var: torch.Tensor | None,
  2046. training: bool,
  2047. exponential_average_factor: float,
  2048. epsilon: float,
  2049. ):
  2050. # In batch norm the output is of the same shape as the input
  2051. out_shape = input_tensor.shape
  2052. # If tensor is provided for running_mean and running_var then use this. If these are not
  2053. # provided then we return the shape of weight tensor. Similar to how this is handled in the decomposition
  2054. save_mean_shape = running_mean.shape if running_mean is not None else weight.shape
  2055. save_var_shape = running_var.shape if running_var is not None else weight.shape
  2056. def pick_memory_format():
  2057. if is_channels_last(input_tensor):
  2058. return torch.channels_last
  2059. if input_tensor.is_contiguous(memory_format=torch.contiguous_format):
  2060. return torch.contiguous_format
  2061. return torch.contiguous_format
  2062. out = input_tensor.new_empty(out_shape).to(memory_format=pick_memory_format())
  2063. if training:
  2064. save_mean = input_tensor.new_empty(save_mean_shape)
  2065. save_var = input_tensor.new_empty(save_var_shape)
  2066. else:
  2067. save_mean = input_tensor.new_empty((0,))
  2068. save_var = input_tensor.new_empty((0,))
  2069. return out, save_mean, save_var
  2070. @register_meta(aten.convolution.default)
  2071. def meta_conv(
  2072. input_tensor: torch.Tensor,
  2073. weight: torch.Tensor,
  2074. bias: torch.Tensor,
  2075. stride: list[int],
  2076. padding: list[int],
  2077. dilation: list[int],
  2078. is_transposed: bool,
  2079. output_padding: list[int],
  2080. groups: int,
  2081. ):
  2082. shape_out = calc_conv_nd_return_shape(
  2083. input_tensor,
  2084. weight,
  2085. stride,
  2086. padding,
  2087. dilation,
  2088. is_transposed,
  2089. groups,
  2090. output_padding if is_transposed else None,
  2091. )
  2092. input_channels_dim = 1
  2093. output_channels_dim = 1
  2094. if input_tensor.size(input_channels_dim) == 0:
  2095. shape_out[output_channels_dim] = 0
  2096. out = input_tensor.new_empty(shape_out)
  2097. return out
  2098. if torch._C._has_mkldnn:
  2099. _meta_lib_dont_use_me_use_register_meta_for_mkldnn = torch.library.Library(
  2100. "mkldnn", "IMPL", "Meta"
  2101. )
  2102. @register_meta(torch.ops.mkldnn._convolution_pointwise.default)
  2103. def meta_mkldnn_convolution_default(
  2104. input_tensor,
  2105. weight,
  2106. bias,
  2107. padding,
  2108. stride,
  2109. dilation,
  2110. groups,
  2111. attr,
  2112. scalars,
  2113. algorithm,
  2114. ):
  2115. shape_out = calc_conv_nd_return_shape(
  2116. input_tensor, weight, stride, padding, dilation, False, groups, []
  2117. )
  2118. out = input_tensor.new_empty(shape_out)
  2119. out_memory_format = torch.channels_last
  2120. if input_tensor.dim() == 5:
  2121. out_memory_format = torch.channels_last_3d
  2122. out = out.to(memory_format=out_memory_format) # type: ignore[call-overload]
  2123. return out
  2124. @register_meta(torch.ops.mkldnn._linear_pointwise.default)
  2125. def meta_linear_pointwise_default(
  2126. input_tensor, weight, bias, attr, scalars, algorithm
  2127. ):
  2128. return input_tensor.new_empty((*input_tensor.shape[:-1], weight.shape[0]))
  2129. if torch._C.has_mkl:
  2130. _meta_lib_dont_use_me_use_register_meta_for_mkl = torch.library.Library(
  2131. "mkl", "IMPL", "Meta"
  2132. )
  2133. @register_meta(torch.ops.mkl._mkl_linear)
  2134. def meta_mkl_linear(input_tensor, packed_weight, orig_weight, bias, batch_size):
  2135. return input_tensor.new_empty(
  2136. (*input_tensor.shape[:-1], orig_weight.shape[0])
  2137. )
  2138. _meta_lib_dont_use_me_use_register_meta_for_onednn = torch.library.Library(
  2139. "onednn", "IMPL", "Meta"
  2140. )
  2141. @register_meta(torch.ops.onednn.qconv2d_pointwise.default)
  2142. @register_meta(torch.ops.onednn.qconv_pointwise.default)
  2143. @register_meta(torch.ops.onednn.qconv_pointwise.tensor)
  2144. def meta_qconv_pointwise(
  2145. x,
  2146. x_scale,
  2147. x_zp,
  2148. w, # prepacked_weight
  2149. w_scale,
  2150. w_zp,
  2151. bias,
  2152. stride,
  2153. padding,
  2154. dilation,
  2155. groups,
  2156. output_scale,
  2157. output_zero_point,
  2158. output_dtype,
  2159. attr,
  2160. scalars,
  2161. algorithm,
  2162. ):
  2163. shape_out = calc_conv_nd_return_shape(
  2164. x,
  2165. w,
  2166. stride,
  2167. padding,
  2168. dilation,
  2169. False,
  2170. groups,
  2171. None,
  2172. )
  2173. if output_dtype is None:
  2174. output_dtype = x.dtype
  2175. assert output_dtype in [
  2176. torch.float32,
  2177. torch.bfloat16,
  2178. torch.uint8,
  2179. torch.int8,
  2180. torch.float8_e4m3fn,
  2181. ]
  2182. out = x.new_empty(shape_out, dtype=output_dtype)
  2183. assert len(shape_out) in [3, 4, 5], (
  2184. "Expect output to be 3d/4d/5d for conv1d/2d/3d"
  2185. )
  2186. format = {
  2187. 3: torch.contiguous_format,
  2188. 4: torch.channels_last,
  2189. 5: torch.channels_last_3d,
  2190. }[len(shape_out)]
  2191. out = out.to(memory_format=format)
  2192. return out
  2193. @register_meta(torch.ops.onednn.qconv2d_pointwise.binary)
  2194. @register_meta(torch.ops.onednn.qconv2d_pointwise.binary_tensor)
  2195. def meta_qconv2d_pointwise_binary(
  2196. x,
  2197. x_scale,
  2198. x_zp,
  2199. w,
  2200. w_scale,
  2201. w_zp,
  2202. accum,
  2203. bias,
  2204. stride,
  2205. padding,
  2206. dilation,
  2207. groups,
  2208. output_scale,
  2209. output_zero_point,
  2210. output_dtype,
  2211. accum_scale,
  2212. accum_zero_point,
  2213. binary_op_name,
  2214. alpha,
  2215. unary_op_name,
  2216. unary_op_args,
  2217. unary_op_algorithm,
  2218. ):
  2219. assert binary_op_name == "sum"
  2220. return accum
  2221. @register_meta(torch.ops.onednn.qlinear_pointwise.default)
  2222. @register_meta(torch.ops.onednn.qlinear_pointwise.tensor)
  2223. def meta_qlinear_pointwise(
  2224. x,
  2225. x_scale,
  2226. x_zp,
  2227. w,
  2228. w_scale,
  2229. w_zp,
  2230. bias,
  2231. output_scale,
  2232. output_zero_point,
  2233. output_dtype,
  2234. post_op_name,
  2235. post_op_args,
  2236. post_op_algorithm,
  2237. ):
  2238. output_shape = list(x.shape)
  2239. # The weight has been transposed during the qlinear weight prepack process.
  2240. output_shape[-1] = w.shape[1]
  2241. assert output_dtype in [
  2242. torch.float32,
  2243. torch.bfloat16,
  2244. torch.int8,
  2245. torch.uint8,
  2246. torch.float8_e4m3fn,
  2247. ]
  2248. out = x.new_empty(output_shape, dtype=output_dtype)
  2249. return out
  2250. @register_meta(torch.ops.onednn.qlinear_pointwise.binary)
  2251. @register_meta(torch.ops.onednn.qlinear_pointwise.binary_tensor)
  2252. def meta_qlinear_pointwise_binary(
  2253. x,
  2254. x_scale,
  2255. x_zp,
  2256. w,
  2257. w_scale,
  2258. w_zp,
  2259. x_2,
  2260. bias,
  2261. output_scale,
  2262. output_zero_point,
  2263. output_dtype,
  2264. x2_scale,
  2265. x2_zp,
  2266. binary_op_name,
  2267. alpha,
  2268. unary_op_name,
  2269. unary_op_args,
  2270. unary_op_algorithm,
  2271. ):
  2272. if binary_op_name == "sum":
  2273. return x_2
  2274. output_shape = list(x.shape)
  2275. # The weight has been transposed during the qlinear weight prepack process.
  2276. output_shape[-1] = w.shape[1]
  2277. assert output_dtype in [
  2278. torch.float32,
  2279. torch.bfloat16,
  2280. torch.uint8,
  2281. torch.int8,
  2282. torch.float8_e4m3fn,
  2283. ]
  2284. out = x.new_empty(output_shape, dtype=output_dtype)
  2285. return out
  2286. @register_meta(torch.ops.onednn.linear_dynamic_fp16.default)
  2287. @register_meta(torch.ops.onednn.linear_relu_dynamic_fp16.default)
  2288. def meta_linear_dynamic_fp16(
  2289. x,
  2290. w,
  2291. bias,
  2292. ):
  2293. output_shape = list(x.shape)
  2294. # The weight has been transposed during the qlinear weight prepack process.
  2295. output_shape[-1] = w.shape[1]
  2296. out = x.new_empty(output_shape)
  2297. return out
  2298. _meta_lib_dont_use_me_use_register_meta_for_quantized = torch.library.Library(
  2299. "quantized", "IMPL", "Meta"
  2300. )
  2301. @register_meta(torch.ops.quantized.max_pool2d)
  2302. def meta_quantized_max_pool2d(
  2303. input,
  2304. kernel_size,
  2305. stride=(),
  2306. padding=(0,),
  2307. dilation=(1,),
  2308. ceil_mode=False,
  2309. ):
  2310. (
  2311. nInputPlane,
  2312. outputHeight,
  2313. outputWidth,
  2314. ) = max_pool2d_checks_and_compute_shape(
  2315. input, kernel_size, stride, padding, dilation, ceil_mode
  2316. )
  2317. nbatch = input.size(-4) if input.dim() == 4 else 1
  2318. memory_format = torch.channels_last
  2319. if input.dim() == 3:
  2320. size = [nInputPlane, outputHeight, outputWidth]
  2321. else:
  2322. size = [nbatch, nInputPlane, outputHeight, outputWidth]
  2323. return torch.empty(
  2324. size,
  2325. dtype=input.dtype,
  2326. device=input.device,
  2327. memory_format=memory_format,
  2328. )
  2329. @register_meta(torch.ops.quantized.int4mm_packed_weight_cpu)
  2330. def meta_int4mm_packed_weight_cpu(x, w, q_group_size, q_scale_and_zeros):
  2331. torch._check(x.dim() == 2, lambda: f"x must be a 2D tensor, got {x.dim()}D")
  2332. torch._check(w.dim() == 2, lambda: f"w must be a 2D tensor, got {w.dim()}D")
  2333. torch._check(
  2334. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  2335. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  2336. )
  2337. torch._check(
  2338. w.dtype == torch.uint8, lambda: f"expected w to be uint8, got {w.dtype}"
  2339. )
  2340. torch._check(
  2341. q_group_size.dtype == torch.int64,
  2342. lambda: f"q_group_size must be int64, got {q_group_size.dtype}",
  2343. )
  2344. torch._check(
  2345. q_scale_and_zeros.dtype == x.dtype,
  2346. lambda: f"q_scale_and_zeros must have the same dtype as x, got {q_scale_and_zeros.dtype}",
  2347. )
  2348. return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
  2349. # from check_dim_size() in aten/src/ATen/TensorUtils.cpp.
  2350. def check_dim_size(tensor, dim, dim_size, size):
  2351. torch._check(
  2352. tensor.dim() == dim and tensor.shape[dim_size] == size,
  2353. lambda: f"Expected a tensor of dimension {dim} and tensor.size[{dim_size}] == {size}, "
  2354. + f"but got : dimension {tensor.dim()} and tensor.size[{dim_size}] = {tensor.shape[dim_size]}",
  2355. )
  2356. @register_meta(aten.avg_pool2d.default)
  2357. def meta_avg_pool2d(
  2358. input,
  2359. kernel_size,
  2360. stride=(),
  2361. padding=(0,),
  2362. ceil_mode=False,
  2363. count_include_pad=True,
  2364. divisor_override=None,
  2365. ):
  2366. def unpack(name, val):
  2367. torch._check(
  2368. len(val) in [1, 2],
  2369. lambda: f"avg_pool2d: {name} must either be a single int, or a tuple of two ints",
  2370. )
  2371. H = val[0]
  2372. W = H if len(val) == 1 else val[1]
  2373. return H, W
  2374. kH, kW = unpack("kernel_size", kernel_size)
  2375. torch._check(
  2376. len(stride) in [0, 1, 2],
  2377. lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
  2378. )
  2379. torch._check(
  2380. input.dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64],
  2381. lambda: f""""avg_pool2d" not implemented for '{input.dtype.__str__()}'""",
  2382. )
  2383. if len(stride) == 0:
  2384. dH, dW = kH, kW
  2385. elif len(stride) == 1:
  2386. dH, dW = stride[0], stride[0]
  2387. else:
  2388. dH, dW = unpack("stride", stride)
  2389. padH, padW = unpack("padding", padding)
  2390. torch._check(
  2391. divisor_override is None or divisor_override != 0,
  2392. lambda: "divisor must be not zero",
  2393. )
  2394. nbatch = input.size(-4) if input.dim() == 4 else 1
  2395. nInputPlane = input.size(-3)
  2396. inputHeight = input.size(-2)
  2397. inputWidth = input.size(-1)
  2398. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
  2399. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
  2400. memory_format = utils.suggest_memory_format(input)
  2401. pool2d_shape_check(
  2402. input,
  2403. kH,
  2404. kW,
  2405. dH,
  2406. dW,
  2407. padH,
  2408. padW,
  2409. 1,
  2410. 1,
  2411. nInputPlane,
  2412. inputHeight,
  2413. inputWidth,
  2414. outputHeight,
  2415. outputWidth,
  2416. memory_format,
  2417. )
  2418. if input.dim() == 3:
  2419. size = [nInputPlane, outputHeight, outputWidth]
  2420. else:
  2421. size = [nbatch, nInputPlane, outputHeight, outputWidth]
  2422. return torch.empty(
  2423. size,
  2424. dtype=input.dtype,
  2425. device=input.device,
  2426. memory_format=memory_format,
  2427. )
  2428. # from avg_pool2d_backward_shape_check() in aten/src/ATen/native/Pool.h.
  2429. def avg_pool2d_backward_shape_check(
  2430. input,
  2431. gradOutput,
  2432. nbatch,
  2433. kH,
  2434. kW,
  2435. dH,
  2436. dW,
  2437. padH,
  2438. padW,
  2439. nInputPlane,
  2440. inputHeight,
  2441. inputWidth,
  2442. outputHeight,
  2443. outputWidth,
  2444. mem_format,
  2445. ):
  2446. pool2d_shape_check(
  2447. input,
  2448. kH,
  2449. kW,
  2450. dH,
  2451. dW,
  2452. padH,
  2453. padW,
  2454. 1,
  2455. 1,
  2456. nInputPlane,
  2457. inputHeight,
  2458. inputWidth,
  2459. outputHeight,
  2460. outputWidth,
  2461. mem_format,
  2462. )
  2463. ndim = input.dim()
  2464. nOutputPlane = nInputPlane
  2465. check_dim_size(gradOutput, ndim, ndim - 3, nOutputPlane)
  2466. check_dim_size(gradOutput, ndim, ndim - 2, outputHeight)
  2467. check_dim_size(gradOutput, ndim, ndim - 1, outputWidth)
  2468. # Don't override the C++ registration.
  2469. @register_meta(aten.avg_pool2d_backward.default)
  2470. def meta_avg_pool2d_backward(
  2471. gradOutput_,
  2472. input,
  2473. kernel_size,
  2474. stride,
  2475. padding,
  2476. ceil_mode,
  2477. count_include_pad,
  2478. divisor_override,
  2479. ):
  2480. # From aten/src/ATen/native/AveragePool2d.cpp structured kernel meta func.
  2481. torch._check(
  2482. len(kernel_size) == 1 or len(kernel_size) == 2,
  2483. lambda: "avg_pool2d: kernel_size must either be a single int, or a tuple of two ints",
  2484. )
  2485. kH = kernel_size[0]
  2486. kW = kH if len(kernel_size) == 1 else kernel_size[1]
  2487. torch._check(
  2488. len(stride) == 0 or len(stride) == 1 or len(stride) == 2,
  2489. lambda: "avg_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
  2490. )
  2491. dH = kH if len(stride) == 0 else stride[0]
  2492. dW = kW if len(stride) == 0 else dH if len(stride) == 1 else stride[1]
  2493. torch._check(
  2494. len(padding) == 1 or len(padding) == 2,
  2495. lambda: "avg_pool2d: padding must either be a single int, or a tuple of two ints",
  2496. )
  2497. padH = padding[0]
  2498. padW = padH if len(padding) == 1 else padding[1]
  2499. torch._check(
  2500. divisor_override is None or divisor_override != 0,
  2501. lambda: "divisor must be not zero",
  2502. )
  2503. input_size = input.shape
  2504. nbatch = input_size[-4] if input.dim() == 4 else 1
  2505. nInputPlane = input_size[-3]
  2506. inputHeight = input_size[-2]
  2507. inputWidth = input_size[-1]
  2508. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, 1, ceil_mode)
  2509. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, 1, ceil_mode)
  2510. mem_format = utils.suggest_memory_format(input)
  2511. avg_pool2d_backward_shape_check(
  2512. input,
  2513. gradOutput_,
  2514. nbatch,
  2515. kH,
  2516. kW,
  2517. dH,
  2518. dW,
  2519. padH,
  2520. padW,
  2521. nInputPlane,
  2522. inputHeight,
  2523. inputWidth,
  2524. outputHeight,
  2525. outputWidth,
  2526. mem_format,
  2527. )
  2528. return torch.empty(
  2529. input_size,
  2530. dtype=input.dtype,
  2531. device=input.device,
  2532. memory_format=mem_format,
  2533. )
  2534. @register_meta(aten.avg_pool3d)
  2535. @out_wrapper()
  2536. def meta_avg_pool3d(
  2537. input,
  2538. kernel_size,
  2539. stride=(),
  2540. padding=(0,),
  2541. ceil_mode=False,
  2542. count_include_pad=True,
  2543. divisor_override=None,
  2544. ):
  2545. torch._check(
  2546. len(kernel_size) in (1, 3),
  2547. lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
  2548. )
  2549. kT = kernel_size[0]
  2550. kH = kT if len(kernel_size) == 1 else kernel_size[1]
  2551. kW = kT if len(kernel_size) == 1 else kernel_size[2]
  2552. torch._check(
  2553. not stride or len(stride) in (1, 3),
  2554. lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
  2555. )
  2556. torch._check(
  2557. input.dtype not in [torch.uint8, torch.uint16, torch.uint32, torch.uint64],
  2558. lambda: f""""avg_pool3d" not implemented for '{input.dtype.__str__()}'""",
  2559. )
  2560. dT = kT if not stride else stride[0]
  2561. dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
  2562. dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
  2563. torch._check(
  2564. len(padding) in (1, 3),
  2565. lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
  2566. )
  2567. padT = padding[0]
  2568. padH = padT if len(padding) == 1 else padding[1]
  2569. padW = padT if len(padding) == 1 else padding[2]
  2570. torch._check(
  2571. input.ndim in (4, 5),
  2572. lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
  2573. )
  2574. torch._check(
  2575. not divisor_override or divisor_override != 0,
  2576. lambda: "divisor must be not zero",
  2577. )
  2578. nbatch = input.size(0)
  2579. nslices = input.size(-4)
  2580. itime = input.size(-3)
  2581. iheight = input.size(-2)
  2582. iwidth = input.size(-1)
  2583. otime = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
  2584. oheight = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
  2585. owidth = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
  2586. pool3d_shape_check(
  2587. input,
  2588. nslices,
  2589. kT,
  2590. kH,
  2591. kW,
  2592. dT,
  2593. dH,
  2594. dW,
  2595. padT,
  2596. padH,
  2597. padW,
  2598. 1,
  2599. 1,
  2600. 1,
  2601. itime,
  2602. iheight,
  2603. iwidth,
  2604. otime,
  2605. oheight,
  2606. owidth,
  2607. "avg_pool3d()",
  2608. check_input_size=True,
  2609. )
  2610. if input.ndim == 4:
  2611. return input.new_empty((nslices, otime, oheight, owidth))
  2612. else:
  2613. return input.new_empty((nbatch, nslices, otime, oheight, owidth))
  2614. @register_meta(aten.avg_pool3d_backward)
  2615. @out_wrapper("grad_input")
  2616. def meta_avg_pool3d_backward(
  2617. grad_output,
  2618. input,
  2619. kernel_size,
  2620. stride,
  2621. padding,
  2622. ceil_mode,
  2623. count_include_pad,
  2624. divisor_override,
  2625. ):
  2626. torch._check(
  2627. len(kernel_size) in (1, 3),
  2628. lambda: "avg_pool3d: kernel_size must be a single int, or a tuple of three ints",
  2629. )
  2630. kT = kernel_size[0]
  2631. kH = kT if len(kernel_size) == 1 else kernel_size[1]
  2632. kW = kT if len(kernel_size) == 1 else kernel_size[2]
  2633. torch._check(
  2634. not stride or len(stride) in (1, 3),
  2635. lambda: "avg_pool3d: stride must be omitted, a single int, or a tuple of three ints",
  2636. )
  2637. dT = kT if not stride else stride[0]
  2638. dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
  2639. dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
  2640. torch._check(
  2641. len(padding) in (1, 3),
  2642. lambda: "avg_pool3d: padding must be a single int, or a tuple of three ints",
  2643. )
  2644. padT = padding[0]
  2645. padH = padT if len(padding) == 1 else padding[1]
  2646. padW = padT if len(padding) == 1 else padding[2]
  2647. torch._check(
  2648. input.ndim in (4, 5),
  2649. lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
  2650. )
  2651. torch._check(
  2652. not divisor_override or divisor_override != 0,
  2653. lambda: "divisor must be not zero",
  2654. )
  2655. nslices = input.size(-4)
  2656. itime = input.size(-3)
  2657. iheight = input.size(-2)
  2658. iwidth = input.size(-1)
  2659. otime_for_shape_check = pooling_output_shape(itime, kT, padT, dT, 1, ceil_mode)
  2660. oheight_for_shape_check = pooling_output_shape(iheight, kH, padH, dH, 1, ceil_mode)
  2661. owidth_for_shape_check = pooling_output_shape(iwidth, kW, padW, dW, 1, ceil_mode)
  2662. avg_pool3d_backward_shape_check(
  2663. input,
  2664. grad_output,
  2665. nslices,
  2666. kT,
  2667. kH,
  2668. kW,
  2669. dT,
  2670. dH,
  2671. dW,
  2672. padT,
  2673. padH,
  2674. padW,
  2675. itime,
  2676. iheight,
  2677. iwidth,
  2678. otime_for_shape_check,
  2679. oheight_for_shape_check,
  2680. owidth_for_shape_check,
  2681. "avg_pool3d_backward()",
  2682. )
  2683. return input.new_empty(input.shape)
  2684. @register_meta(aten._adaptive_avg_pool2d.default)
  2685. def meta_adaptive_avg_pool2d(self, output_size):
  2686. torch._check(
  2687. self.ndim == 3 or self.ndim == 4,
  2688. lambda: f"Expected 3D or 4D tensor, but got {self.shape}",
  2689. )
  2690. output_shape = self.shape[:-2] + tuple(output_size)
  2691. memory_format = utils.suggest_memory_format(self)
  2692. # need to set memory_format to preserve the memory format of the input
  2693. # channel last input should have channel last output
  2694. return torch.empty(
  2695. output_shape,
  2696. dtype=self.dtype,
  2697. device=self.device,
  2698. memory_format=memory_format,
  2699. )
  2700. @register_meta(aten._adaptive_avg_pool3d.default)
  2701. def meta_adaptive_avg_pool3d(self, output_size):
  2702. torch._check(
  2703. self.ndim == 4 or self.ndim == 5,
  2704. lambda: f"Expected 4D or 5D tensor, but got {self.shape}",
  2705. )
  2706. return self.new_empty(self.shape[:-3] + tuple(output_size))
  2707. @register_meta(aten._adaptive_avg_pool2d_backward.default)
  2708. def meta__adaptive_avg_pool2d_backward(grad_out, self):
  2709. ndim = grad_out.ndim
  2710. for i in range(1, ndim):
  2711. torch._check(
  2712. grad_out.size(i) > 0,
  2713. lambda: f"adaptive_avg_pool2d_backward(): Expected grad_output to have non-zero \
  2714. size for non-batch dimensions, {grad_out.shape} with dimension {i} being empty",
  2715. )
  2716. torch._check(
  2717. ndim == 3 or ndim == 4,
  2718. lambda: f"adaptive_avg_pool2d_backward(): Expected 3D or 4D tensor, but got {self.shape}",
  2719. )
  2720. torch._check(
  2721. self.dtype == grad_out.dtype,
  2722. lambda: f"expected dtype {self.dtype} for `grad_output` but got dtype {grad_out.dtype}",
  2723. )
  2724. memory_format = torch.contiguous_format
  2725. if is_channels_last(self):
  2726. memory_format = torch.channels_last
  2727. return self.new_empty(self.shape).to(memory_format=memory_format)
  2728. @register_meta(aten._adaptive_avg_pool3d_backward)
  2729. @out_wrapper("grad_input")
  2730. def meta__adaptive_avg_pool3d_backward(grad_output, self):
  2731. _adaptive_pool_empty_output_check(grad_output, "adaptive_avg_pool3d_backward")
  2732. return torch.empty_like(self, memory_format=torch.legacy_contiguous_format)
  2733. def _adaptive_pool_empty_output_check(grad_output: Tensor, arg_name: str):
  2734. ndim = grad_output.ndim
  2735. for i in range(1, ndim):
  2736. torch._check(
  2737. grad_output.size(i) > 0,
  2738. lambda: (
  2739. f"{arg_name}(): Expected grad_output to have non-zero size for non-batch dimensions, "
  2740. f"but grad_output has sizes {grad_output.shape} with dimension {i} being empty"
  2741. ),
  2742. )
  2743. @register_meta(aten.adaptive_max_pool2d)
  2744. @out_wrapper("out", "indices")
  2745. def meta_adaptive_max_pool2d(input, output_size):
  2746. ndim = input.ndim
  2747. torch._check(
  2748. ndim in (3, 4),
  2749. lambda: f"adaptive_max_pool2d(): Expected 3D or 4D tensor, but got: {input.shape}",
  2750. )
  2751. for i in range(1, ndim):
  2752. torch._check(
  2753. input.size(i) > 0,
  2754. lambda: (
  2755. f"adaptive_max_pool2d(): Expected input to have non-zero size for non-batch dimensions, "
  2756. f"but input has sizes {input.shape} with dimension {i} being empty"
  2757. ),
  2758. )
  2759. torch._check(
  2760. len(output_size) == 2,
  2761. lambda: "adaptive_max_pool2d(): internal error: output_size.size() must be 2",
  2762. )
  2763. dimH = 1
  2764. sizeB = 1
  2765. sizeD = 0
  2766. if input.ndim == 4:
  2767. sizeB = input.size(0)
  2768. dimH += 1
  2769. sizeD = input.size(dimH - 1)
  2770. osizeH, osizeW = output_size
  2771. if input.ndim == 3:
  2772. out_shape = (sizeD, osizeH, osizeW)
  2773. out = input.new_empty(out_shape)
  2774. indices = input.new_empty(out_shape, dtype=torch.int64)
  2775. return out, indices
  2776. else:
  2777. out_shape = (sizeB, sizeD, osizeH, osizeW) # type: ignore[assignment]
  2778. memory_format = utils.suggest_memory_format(input)
  2779. out = input.new_empty(out_shape).to(memory_format=memory_format)
  2780. indices = input.new_empty(out_shape, dtype=torch.int64).to(
  2781. memory_format=memory_format
  2782. )
  2783. return out, indices
  2784. @register_meta(aten.adaptive_max_pool2d_backward)
  2785. @out_wrapper("grad_input")
  2786. def meta_adaptive_max_pool2d_backward(grad_output, input, indices):
  2787. ndim = grad_output.ndim
  2788. torch._check(
  2789. ndim in (3, 4),
  2790. lambda: f"adaptive_max_pooling2d_backward(): Expected 3D or 4D grad_output, but got: {grad_output.shape}",
  2791. )
  2792. _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool2d_backward")
  2793. torch._check(
  2794. input.dtype == grad_output.dtype,
  2795. lambda: f"expected dtype {input.dtype} for `grad_output` but got dtype {grad_output.dtype}",
  2796. )
  2797. memory_format = utils.suggest_memory_format(input)
  2798. return input.new_empty(input.shape).to(memory_format=memory_format)
  2799. @register_meta(aten.adaptive_max_pool3d)
  2800. @out_wrapper("out", "indices")
  2801. def meta_adaptive_max_pool3d(input, output_size):
  2802. ndim = input.ndim
  2803. torch._check(
  2804. ndim in (4, 5),
  2805. lambda: f"adaptive_max_pool3d(): Expected 4D or 5D tensor, but got: {input.shape}",
  2806. )
  2807. for i in range(1, ndim):
  2808. torch._check(
  2809. input.size(i) > 0,
  2810. lambda: (
  2811. f"adaptive_max_pool3d(): Expected input to have non-zero size for non-batch dimensions, "
  2812. f"but input has sizes {input.shape} with dimension {i} being empty"
  2813. ),
  2814. )
  2815. torch._check(
  2816. len(output_size) == 3,
  2817. lambda: "adaptive_max_pool3d(): internal error: output_size.size() must be 3",
  2818. )
  2819. dimD = 0
  2820. sizeB = 1
  2821. sizeD = 0
  2822. if ndim == 5:
  2823. sizeB = input.size(0)
  2824. dimD += 1
  2825. sizeD = input.size(dimD)
  2826. osizeT, osizeH, osizeW = output_size
  2827. if ndim == 4:
  2828. out_shape = (sizeD, osizeT, osizeH, osizeW)
  2829. else:
  2830. out_shape = (sizeB, sizeD, osizeT, osizeH, osizeW) # type: ignore[assignment]
  2831. out = input.new_empty(out_shape)
  2832. indices = input.new_empty(out_shape, dtype=torch.int64)
  2833. return out, indices
  2834. @register_meta(aten.adaptive_max_pool3d_backward)
  2835. @out_wrapper("grad_input")
  2836. def meta_adaptive_max_pool3d_backward(grad_output, input, indices):
  2837. _adaptive_pool_empty_output_check(grad_output, "adaptive_max_pool3d_backward")
  2838. return input.new_empty(input.shape)
  2839. @register_meta(aten.repeat_interleave.Tensor)
  2840. def meta_repeat_interleave_Tensor(repeats, output_size=None):
  2841. if output_size is None:
  2842. raise RuntimeError("cannot repeat_interleave a meta tensor without output_size")
  2843. return repeats.new_empty(output_size)
  2844. @register_meta([aten.complex.default, aten.complex.out])
  2845. @out_wrapper()
  2846. def meta_complex(real, imag):
  2847. assert real.dtype.is_floating_point
  2848. assert imag.dtype.is_floating_point
  2849. result = elementwise_meta(
  2850. real.to(corresponding_complex_dtype(real.dtype)),
  2851. imag.to(corresponding_complex_dtype(imag.dtype)),
  2852. type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT,
  2853. )
  2854. return result
  2855. @register_meta([aten.nonzero_static.default, aten.nonzero_static.out])
  2856. @out_wrapper()
  2857. def nonzero_static(self, *, size, fill_value: int = -1):
  2858. return self.new_empty((size, self.dim()), dtype=torch.long)
  2859. @register_meta([torch.ops.aten.nonzero.default, torch.ops.aten.nonzero.out])
  2860. @out_wrapper()
  2861. def nonzero(self):
  2862. torch._check_not_implemented(
  2863. exp_config.meta_nonzero_assume_all_nonzero,
  2864. lambda: "The register_meta function for torch.nonzero() raises unimplemented by default, "
  2865. "as a correct data-independent implementation does not exist. This implementation "
  2866. "returns a fake value, assuming all elements of the tensor are non-zero. "
  2867. "To enable this registration, please set "
  2868. "'torch.fx.experimental._config.meta_nonzero_assume_all_nonzero' to True.",
  2869. )
  2870. return torch.empty_strided(
  2871. (self.numel(), self.dim()),
  2872. (1, self.numel()),
  2873. dtype=torch.long,
  2874. device=self.device,
  2875. )
@register_meta([aten.index.Tensor, aten._unsafe_index.Tensor])
def meta_index_Tensor(self, indices):
    """Meta kernel for advanced indexing (``aten.index.Tensor`` / ``aten._unsafe_index.Tensor``).

    Computes the shape and strides of ``self[indices]`` without reading data:

    1. Validate index dtypes (long/int/int8/bool). Replace each bool/int8 mask
       with one long index per masked dimension, taken from ``nonzero()``.
    2. Broadcast the indices together (``refs._maybe_broadcast``) and pad
       with ``None`` up to ``self.ndim``.
    3. If the non-None indices are not adjacent, permute the indexed dims to
       the front (NumPy/eager semantics).
    4. Output shape is ``before_shape + replacement_shape + after_shape``.
    5. Unless ``self`` is (known to be) empty, choose output strides by
       mimicking eager's TensorIterator dimension ordering on a restrided
       view of ``self``.
    """
    torch._check(bool(indices), lambda: "at least one index must be provided")
    # aten::index is the internal advanced indexing implementation
    # checkIndexTensorTypes and expandTensors
    result: list[Tensor | None] = []
    for i, index in enumerate(indices):
        if index is not None:
            torch._check(
                index.dtype in [torch.long, torch.int, torch.int8, torch.bool],
                lambda: "tensors used as indices must be long, int, byte or bool tensors",
            )
            if index.dtype in [torch.int8, torch.bool]:
                # A mask covering index.ndim dims of self expands into
                # index.ndim separate long indices (columns of nonzero()).
                nonzero = index.nonzero()
                # k = first self dim this mask applies to.
                k = len(result)
                torch._check_index(
                    k + index.ndim <= self.ndim,
                    lambda: f"too many indices for tensor of dimension {self.ndim}",
                )
                for j in range(index.ndim):
                    torch._check_index(
                        index.shape[j] == self.shape[k + j],
                        lambda: f"The shape of the mask {index.shape} at index {i} "
                        f"does not match the shape of the indexed tensor {self.shape} at index {k + j}",
                    )
                    result.append(nonzero.select(1, j))
            else:
                result.append(index)
        else:
            result.append(index)
    indices = result
    torch._check(
        len(indices) <= self.ndim,
        lambda: f"too many indices for tensor of dimension {self.ndim} (got {len(indices)})",
    )
    # expand_outplace
    import torch._refs as refs  # avoid import cycle in mypy

    indices = list(refs._maybe_broadcast(*indices))
    # add missing null tensors
    while len(indices) < self.ndim:
        indices.append(None)

    # hasContiguousSubspace
    #   true if all non-null tensors are adjacent
    # See:
    # https://numpy.org/doc/stable/user/basics.indexing.html#combining-advanced-and-basic-indexing
    # https://stackoverflow.com/questions/53841497/why-does-numpy-mixed-basic-advanced-indexing-depend-on-slice-adjacency
    # State machine: 0 = before any index, 1 = inside the run of indices,
    # 2 = after the run. Seeing another index in state 2 means the indexed
    # dims are split by a None.
    state = 0
    has_contiguous_subspace = False
    for index in indices:
        if state == 0:
            if index is not None:
                state = 1
        elif state == 1:
            if index is None:
                state = 2
        else:
            if index is not None:
                break
    else:
        # for/else: reached only when the loop did not `break`.
        has_contiguous_subspace = True

    # transposeToFront
    # This is the logic that causes the newly inserted dimensions to show up
    # at the beginning of the tensor, if they're not contiguous
    if not has_contiguous_subspace:
        dims = []
        transposed_indices = []
        # Indexed dims first, then the remaining (None) dims, each in order.
        for i, index in enumerate(indices):
            if index is not None:
                dims.append(i)
                transposed_indices.append(index)
        for i, index in enumerate(indices):
            if index is None:
                dims.append(i)
                transposed_indices.append(index)
        self = self.permute(dims)
        indices = transposed_indices

    # AdvancedIndex::AdvancedIndex
    # Now we can assume the indices have contiguous subspace
    # This is simplified from AdvancedIndex which goes to more effort
    # to put the input and indices in a form so that TensorIterator can
    # take them. If we write a ref for this, probably that logic should
    # get implemented
    before_shape: list[int] = []
    after_shape: list[int] = []
    replacement_shape: list[int] = []
    for dim, index in enumerate(indices):
        if index is None:
            # Un-indexed dims keep their size and land before or after the
            # block of indexed dims depending on where they sit.
            if replacement_shape:
                after_shape.append(self.shape[dim])
            else:
                before_shape.append(self.shape[dim])
        else:
            # All indices were broadcast together, so every non-None index
            # has the same shape; the last one seen is as good as any.
            replacement_shape = list(index.shape)

    def _restride_src(self):
        """
        This follows restride_src in TensorAdvancedIndexing.cpp
        """
        shape = before_shape + replacement_shape + after_shape
        strides = list(self.stride())
        # Indexed dims get stride 0 in the restrided view.
        # pyrefly: ignore [unsupported-operation]
        strides[len(before_shape) : len(self.shape) - len(after_shape)] = [0] * len(
            replacement_shape
        )
        return self.as_strided(shape, strides)

    out = self.new_empty(before_shape + replacement_shape + after_shape)
    from torch.fx.experimental.symbolic_shapes import guard_or_false

    if guard_or_false(self.numel() == 0):
        # No need to worry about the output strides if self is empty.
        return out

    # Try to follow eager to decide the output stride based on self.
    # Note that perm here is the reverse of the 'perm_' decided by
    # TensorIteratorBase::reorder_dimensions
    restrided_self = _restride_src(self)
    perm, _ = utils.compute_elementwise_output_logical_to_physical_perm(restrided_self)

    # Follow TensorIteratorBase::allocate_or_resize_outputs
    # Only restride when the preferred dimension order is not the identity.
    if list(perm) != list(range(len(perm))):
        perm_shape = utils.apply_perm(out.shape, perm)
        new_stride = utils.make_contiguous_strides_for(perm_shape)
        new_stride = utils.apply_perm(new_stride, utils.invert_perm(perm))
        out = out.as_strided(out.size(), new_stride)
    return out
  2997. @register_meta([aten.convolution_backward.default])
  2998. def meta_convolution_backward(
  2999. grad_output_,
  3000. input_,
  3001. weight_,
  3002. bias_sizes_opt,
  3003. stride,
  3004. padding,
  3005. dilation,
  3006. transposed,
  3007. output_padding,
  3008. groups,
  3009. output_mask,
  3010. ):
  3011. # High level logic taken from slow_conv3d_backward_cpu which should
  3012. # be representative of all convolution_backward impls
  3013. backend_grad_input = None
  3014. backend_grad_weight = None
  3015. backend_grad_bias = None
  3016. if output_mask[0]:
  3017. backend_grad_input = grad_output_.new_empty(input_.size())
  3018. if output_mask[1]:
  3019. backend_grad_weight = grad_output_.new_empty(weight_.size())
  3020. if output_mask[2]:
  3021. backend_grad_bias = grad_output_.new_empty(bias_sizes_opt)
  3022. return (backend_grad_input, backend_grad_weight, backend_grad_bias)
  3023. @register_meta([aten.addbmm.default, aten.addbmm.out])
  3024. @out_wrapper(exact_dtype=True)
  3025. def meta_addbmm(self, batch1, batch2, *, beta=1, alpha=1):
  3026. dim1 = batch1.size(1)
  3027. dim2 = batch2.size(2)
  3028. self = self.expand((dim1, dim2))
  3029. torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
  3030. torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
  3031. torch._check(
  3032. batch1.size(0) == batch2.size(0),
  3033. lambda: f"batch1 and batch2 must have same number of batches, got {batch1.size(0)} and {batch2.size(0)}",
  3034. )
  3035. torch._check(
  3036. batch1.size(2) == batch2.size(1),
  3037. lambda: (
  3038. f"Incompatible matrix sizes for bmm ({batch1.size(1)}x{batch1.size(2)} "
  3039. f"and {batch2.size(1)}x{batch2.size(2)})"
  3040. ),
  3041. )
  3042. torch._check(
  3043. self.size(0) == dim1 and self.size(1) == dim2,
  3044. lambda: "self tensor does not match matmul output shape",
  3045. )
  3046. return self.new_empty(self.size())
  3047. @register_meta([aten.randint_like.Tensor])
  3048. def meta_randint_like(self, high, **kwargs):
  3049. return self.new_empty(self.size())
  3050. @register_meta([aten._fused_adam_.default, aten._fused_adamw_.default])
  3051. def meta__fused_adam_(
  3052. self,
  3053. grads,
  3054. exp_avgs,
  3055. exp_avg_sqs,
  3056. max_exp_avg_sqs,
  3057. state_steps,
  3058. *,
  3059. lr,
  3060. beta1,
  3061. beta2,
  3062. weight_decay,
  3063. eps,
  3064. amsgrad,
  3065. maximize,
  3066. grad_scale=None,
  3067. found_inf=None,
  3068. ):
  3069. for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
  3070. torch._check(
  3071. isinstance(l, list),
  3072. lambda: f"exponent must be a tensor list but got {type(l)}",
  3073. )
  3074. @register_meta([aten._fused_adam.default])
  3075. def meta__fused_adam(
  3076. self,
  3077. grads,
  3078. exp_avgs,
  3079. exp_avg_sqs,
  3080. max_exp_avg_sqs,
  3081. state_steps,
  3082. *,
  3083. lr,
  3084. beta1,
  3085. beta2,
  3086. weight_decay,
  3087. eps,
  3088. amsgrad,
  3089. maximize,
  3090. grad_scale=None,
  3091. found_inf=None,
  3092. ):
  3093. for l in [self, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps]:
  3094. torch._check(
  3095. isinstance(l, list),
  3096. lambda: f"exponent must be a tensor list but got {type(l)}",
  3097. )
  3098. def empty_like_list(tensor_list):
  3099. return [torch.empty_like(t) for t in tensor_list]
  3100. return (
  3101. empty_like_list(self),
  3102. empty_like_list(grads),
  3103. empty_like_list(exp_avgs),
  3104. empty_like_list(exp_avg_sqs),
  3105. empty_like_list(max_exp_avg_sqs),
  3106. )
  3107. @register_meta([aten._int_mm])
  3108. @out_wrapper()
  3109. def meta__int_mm(a, b):
  3110. torch._check(a.dim() == 2, lambda: "a must be a 2D tensor")
  3111. torch._check(b.dim() == 2, lambda: "b must be a 2D tensor")
  3112. torch._check(
  3113. a.dtype is torch.int8,
  3114. lambda: f"expected self to be int8, got {a.dtype}",
  3115. )
  3116. torch._check(
  3117. b.dtype is torch.int8,
  3118. lambda: f"expected mat2 to be int8, got {b.dtype}",
  3119. )
  3120. torch._check(
  3121. a.size(1) == b.size(0),
  3122. lambda: (
  3123. f"Incompatible matrix sizes for _int_mm ({a.size(0)}x{a.size(1)} "
  3124. f"and {b.size(0)}x{b.size(1)})"
  3125. ),
  3126. )
  3127. return a.new_empty((a.size(0), b.size(1)), dtype=torch.int32)
  3128. @register_meta([aten._convert_weight_to_int4pack])
  3129. def meta__convert_weight_to_int4pack(w, inner_k_tiles):
  3130. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3131. torch._check(
  3132. w.dtype is torch.uint8,
  3133. lambda: f"expected w to be uint8, got {w.dtype}",
  3134. )
  3135. n = w.size(0)
  3136. k = w.size(1) * 2 # w is [n][k / 2] uint8
  3137. return w.new_empty(
  3138. (
  3139. n // 8,
  3140. k // (inner_k_tiles * 16),
  3141. 32,
  3142. inner_k_tiles // 2,
  3143. ),
  3144. dtype=torch.int32,
  3145. )
  3146. @register_meta([aten._convert_weight_to_int4pack_for_cpu])
  3147. def meta__convert_weight_to_int4pack_for_cpu(w, inner_k_tiles):
  3148. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3149. torch._check(
  3150. w.dtype is torch.int32,
  3151. lambda: f"expected w to be int32, got {w.dtype}",
  3152. )
  3153. n = w.size(0)
  3154. k = w.size(1) # w is [n][k] int32
  3155. return w.new_empty(
  3156. (n, k // 2),
  3157. dtype=torch.uint8,
  3158. )
  3159. @register_meta([aten._weight_int4pack_mm])
  3160. def meta__weight_int4pack_mm(x, w, q_group_size, q_scale_and_zeros):
  3161. torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
  3162. torch._check(w.dim() == 4, lambda: "w must be a 4D tensor")
  3163. torch._check(
  3164. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  3165. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  3166. )
  3167. torch._check(
  3168. w.dtype is torch.int32,
  3169. lambda: f"expected w to be int32, got {w.dtype}",
  3170. )
  3171. return x.new_empty(x.size(0), w.size(0) * 8, dtype=x.dtype)
  3172. @register_meta([aten._weight_int4pack_mm_for_cpu])
  3173. def meta__weight_int4pack_mm_for_cpu(x, w, q_group_size, q_scale_and_zeros):
  3174. torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
  3175. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3176. torch._check(
  3177. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  3178. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  3179. )
  3180. torch._check(
  3181. w.dtype is torch.uint8,
  3182. lambda: f"expected w to be uint8, got {w.dtype}",
  3183. )
  3184. return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
  3185. @register_meta([aten._weight_int4pack_mm_with_scales_and_zeros])
  3186. def _weight_int4pack_mm_with_scales_and_zeros(x, w, q_group_size, qScale, qZeros):
  3187. torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
  3188. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3189. torch._check(
  3190. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  3191. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  3192. )
  3193. torch._check(
  3194. w.dtype is torch.int32,
  3195. lambda: f"expected w to be int32, got {w.dtype}",
  3196. )
  3197. return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
  3198. def kai_roundup(a: int, b: int) -> int:
  3199. return ((a + b - 1) // b) * b
  3200. def get_kai_packed_weight_size(n_bits, N, K, groupsize):
  3201. if n_bits == 4:
  3202. # Works for both fp32 and bf16 Kernels
  3203. if groupsize == K: # channelwise
  3204. # dotprod params only [1x8x32_neon_dotprod]
  3205. kai_nr = 8
  3206. kai_kr = 16
  3207. kai_sr = 2
  3208. kai_num_bytes_sum_rhs = 4 # sizeof(int32_t)
  3209. kai_num_bytes_multiplier_rhs = 4 # sizeof(float)
  3210. kai_num_bytes_bias = 4 # sizeof(float)
  3211. def kai_k_roundedup(k, kr, sr):
  3212. # Since we pack a float and int32 value at the end of the row,
  3213. # we must make sure that k is a multiple of 4 for alignment
  3214. kr_sr_roundedup4 = kai_roundup(kr * sr, 4)
  3215. return kai_roundup(k, kr_sr_roundedup4)
  3216. def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
  3217. k, nr, kr, sr
  3218. ):
  3219. k_internal = kai_k_roundedup(k, kr, sr)
  3220. assert (k_internal % 2) == 0, "k_internal must be even"
  3221. return nr * (
  3222. (k_internal // 2)
  3223. + kai_num_bytes_multiplier_rhs
  3224. + kai_num_bytes_sum_rhs
  3225. + kai_num_bytes_bias
  3226. )
  3227. def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
  3228. n, k, nr, kr, sr
  3229. ):
  3230. num_rows = kai_roundup(n, nr) // nr
  3231. return (
  3232. num_rows
  3233. * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
  3234. k, nr, kr, sr
  3235. )
  3236. )
  3237. return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
  3238. N, K, kai_nr, kai_kr, kai_sr
  3239. )
  3240. elif groupsize % 32 == 0 and K % groupsize == 0: # groupwise
  3241. kai_nr = 8
  3242. kai_kr = 16
  3243. kai_sr = 2
  3244. kai_num_bytes_sum_rhs = 4
  3245. kai_num_bytes_bias = 4
  3246. kai_nr_multiple_of = 4
  3247. kai_bl_multiple_of = 32
  3248. def kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
  3249. n, k, nr, kr, sr, bl
  3250. ):
  3251. assert (bl % kr) == 0
  3252. assert (nr % kai_nr_multiple_of) == 0
  3253. assert (bl % kai_bl_multiple_of) == 0
  3254. num_rows = kai_roundup(n, nr) // nr
  3255. return (
  3256. num_rows
  3257. * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
  3258. k, nr, kr, sr, bl
  3259. )
  3260. )
  3261. def kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
  3262. k, nr, kr, sr, bl
  3263. ):
  3264. assert (bl % kr) == 0
  3265. assert (nr % kai_nr_multiple_of) == 0
  3266. assert (bl % kai_bl_multiple_of) == 0
  3267. # kr and sr are unused in the calculation
  3268. num_bytes_multiplier_rhs = kai_get_bf16_datatype_size_in_bytes()
  3269. num_blocks_per_row = kai_num_blocks_per_row(k, bl)
  3270. num_bytes_per_block = kai_num_bytes_per_block(
  3271. bl, num_bytes_multiplier_rhs
  3272. )
  3273. return nr * (
  3274. (num_bytes_per_block * num_blocks_per_row)
  3275. + kai_num_bytes_sum_rhs
  3276. + kai_num_bytes_bias
  3277. )
  3278. # This function returns size of these datatypes stored as enum. We modify it to just return bf16 datatype
  3279. # https://gitlab.arm.com/kleidi/kleidiai/-/blob/main/kai/kai_common.h?ref_type=heads#L55
  3280. def kai_get_bf16_datatype_size_in_bytes():
  3281. return 2 # 2 bytes
  3282. def kai_num_blocks_per_row(k, bl):
  3283. assert (bl % kai_bl_multiple_of) == 0
  3284. return kai_roundup(k, bl) // bl
  3285. def kai_num_bytes_per_block(bl, num_bytes_multiplier_rhs):
  3286. assert (bl % kai_bl_multiple_of) == 0
  3287. return (bl // 2) + num_bytes_multiplier_rhs
  3288. return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0(
  3289. N, K, kai_nr, kai_kr, kai_sr, groupsize
  3290. )
  3291. @register_meta([aten._dyn_quant_pack_4bit_weight])
  3292. def meta__dyn_quant_pack_4bit_weight(
  3293. weights, scales_zeros, bias: Tensor | None, block_size, in_features, out_features
  3294. ):
  3295. torch._check(
  3296. weights.dtype is torch.uint8,
  3297. lambda: f"expected w to be uint8, got {weights.dtype}",
  3298. )
  3299. if torch.backends.kleidiai.is_available() and (
  3300. (block_size == in_features and scales_zeros.dtype == torch.float)
  3301. or (
  3302. block_size < in_features
  3303. and block_size % 32 == 0
  3304. and in_features % block_size == 0
  3305. and scales_zeros.dtype == torch.bfloat16
  3306. )
  3307. ):
  3308. packed_weight_size = get_kai_packed_weight_size(
  3309. 4, out_features, in_features, block_size
  3310. )
  3311. return weights.new_empty(int(packed_weight_size), dtype=torch.uint8)
  3312. packed_weight_size = weights.numel() + scales_zeros.numel()
  3313. if bias is not None:
  3314. packed_weight_size += bias.numel()
  3315. return weights.new_empty(packed_weight_size, dtype=torch.float)
  3316. @register_meta([aten._dyn_quant_matmul_4bit])
  3317. def meta__dyn_quant_matmul_4bit(
  3318. inp,
  3319. packed_weights,
  3320. block_size,
  3321. in_features,
  3322. out_features,
  3323. ):
  3324. torch._check(inp.dim() == 2, lambda: "input must be a 2D tensor")
  3325. torch._check(
  3326. (inp.dtype == torch.float32)
  3327. or (inp.dtype == torch.bfloat16 and block_size == in_features),
  3328. lambda: (
  3329. f"expected input to be f32 or bf16 (bf16 requires block_size == in_features), "
  3330. f"got {inp.dtype} with block_size={block_size} and in_features={in_features}"
  3331. ),
  3332. )
  3333. M = inp.size(0)
  3334. return inp.new_empty(M, out_features, dtype=inp.dtype)
  3335. @register_meta([aten._weight_int8pack_mm])
  3336. def meta__weight_int8pack_mm(x, w, q_scales):
  3337. torch._check(x.dim() == 2, lambda: "x must be a 2D tensor")
  3338. torch._check(
  3339. x.dtype in [torch.float32, torch.float16, torch.bfloat16],
  3340. lambda: f"expected x to be f32/f16/bf16, got {x.dtype}",
  3341. )
  3342. torch._check(w.dim() == 2, lambda: "w must be a 2D tensor")
  3343. torch._check(
  3344. w.dtype is torch.int8,
  3345. lambda: f"expected w to be int8, got {w.dtype}",
  3346. )
  3347. return x.new_empty(x.size(0), w.size(0), dtype=x.dtype)
  3348. @register_meta(aten._cdist_forward.default)
  3349. def meta_cdist_forward(x1, x2, p, compute_mode):
  3350. torch._check(
  3351. x1.dim() >= 2,
  3352. lambda: f"cdist only supports at least 2D tensors, X1 got: {x1.dim()}D",
  3353. )
  3354. torch._check(
  3355. x2.dim() >= 2,
  3356. lambda: f"cdist only supports at least 2D tensors, X2 got: {x2.dim()}D",
  3357. )
  3358. torch._check(
  3359. x1.size(-1) == x2.size(-1),
  3360. lambda: f"X1 and X2 must have the same number of columns. X1: {x1.size(-1)} X2: {x2.size(-1)}",
  3361. )
  3362. torch._check(
  3363. utils.is_float_dtype(x1.dtype),
  3364. lambda: f"cdist only supports floating-point dtypes, X1 got: {x1.dtype}",
  3365. )
  3366. torch._check(
  3367. utils.is_float_dtype(x2.dtype),
  3368. lambda: f"cdist only supports floating-point dtypes, X2 got: {x2.dtype}",
  3369. )
  3370. torch._check(p >= 0, lambda: "cdist only supports non-negative p values")
  3371. torch._check(
  3372. compute_mode in (None, 0, 1, 2),
  3373. lambda: f"possible modes: None, 0, 1, 2, but was: {compute_mode}",
  3374. )
  3375. r1 = x1.size(-2)
  3376. r2 = x2.size(-2)
  3377. batch_tensor1 = x1.shape[:-2]
  3378. batch_tensor2 = x2.shape[:-2]
  3379. output_shape = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
  3380. output_shape.extend([r1, r2])
  3381. return x1.new_empty(output_shape)
  3382. @register_meta(aten._cdist_backward)
  3383. @out_wrapper()
  3384. def meta_cdist_backward(grad, x1, x2, p, cdist):
  3385. c1 = x1.shape[-1]
  3386. r1 = x1.shape[-2]
  3387. r2 = x2.shape[-2]
  3388. batch_tensor1 = x1.shape[:-2]
  3389. batch_tensor2 = x2.shape[:-2]
  3390. expand_batch_portion = list(torch.broadcast_shapes(batch_tensor1, batch_tensor2))
  3391. tensor1_expand_size = expand_batch_portion.copy()
  3392. tensor1_expand_size.extend([r1, c1])
  3393. batch_product = math.prod(expand_batch_portion)
  3394. if r1 == 0 or r2 == 0 or c1 == 0 or batch_product == 0:
  3395. return torch.zeros_like(x1)
  3396. if tensor1_expand_size != list(x1.shape):
  3397. x1 = x1.expand(tensor1_expand_size)
  3398. return torch.empty_like(x1, memory_format=torch.contiguous_format)
  3399. # NB: This meta function accepts non-meta arguments! When this behavior
  3400. # was originally introduced this was accidental, but it is now load bearing
  3401. # as people are using this so that they can conveniently test code involving
  3402. # embeddings (feeding CPU tensor inputs with meta device EmbeddingBag module)
  3403. @register_meta(aten._embedding_bag.default)
  3404. def meta_embedding_bag(
  3405. weight,
  3406. indices,
  3407. offsets,
  3408. scale_grad_by_freq=False,
  3409. mode=0,
  3410. sparse=False,
  3411. per_sample_weights=None,
  3412. include_last_offset=False,
  3413. padding_idx=-1,
  3414. ):
  3415. torch._check(
  3416. indices.dtype in (torch.long, torch.int),
  3417. lambda: f"expected indices to be long or int, got {indices.dtype}",
  3418. )
  3419. torch._check(
  3420. offsets.dtype in (torch.long, torch.int),
  3421. lambda: f"expected offsets to be long or int, got {offsets.dtype}",
  3422. )
  3423. torch._check(
  3424. utils.is_float_dtype(weight.dtype),
  3425. lambda: f"expected weight to be floating point type, got {weight.dtype}",
  3426. )
  3427. num_bags = offsets.size(0)
  3428. if include_last_offset:
  3429. torch._check(
  3430. num_bags >= 1,
  3431. lambda: "include_last_offset: numBags should be at least 1",
  3432. )
  3433. num_bags -= 1
  3434. output = weight.new_empty(num_bags, weight.size(1))
  3435. if per_sample_weights is not None:
  3436. torch._check(
  3437. mode == MODE_SUM,
  3438. lambda: "embedding_bag: per_sample_weights only supported with mode='sum'",
  3439. )
  3440. torch._check(
  3441. per_sample_weights.ndim == 1,
  3442. lambda: f"expected per_sample_weights to be 1D tensor, got {per_sample_weights.ndim}D",
  3443. )
  3444. torch._check(
  3445. per_sample_weights.numel() == indices.numel(),
  3446. lambda: (
  3447. f"expected per_sample_weights.numel() ({per_sample_weights.numel()} "
  3448. f"to be the same as indices.numel() ({indices.numel()})"
  3449. ),
  3450. )
  3451. def is_fast_path_index_select_scale(src, scale, output, padding_idx):
  3452. return (
  3453. is_fast_path_index_select(src, output, padding_idx) and scale.stride(0) == 1
  3454. )
  3455. def is_fast_path_index_select(src, output, padding_idx):
  3456. return (
  3457. (src.dtype == torch.float or src.dtype == torch.half)
  3458. and src.stride(1) == 1
  3459. and output.stride(1) == 1
  3460. and padding_idx < 0
  3461. )
  3462. def is_fast_path(src, scale, output, padding_idx):
  3463. if scale is not None:
  3464. return is_fast_path_index_select_scale(src, scale, output, padding_idx)
  3465. else:
  3466. return is_fast_path_index_select(src, output, padding_idx)
  3467. if device_hint(offsets) != "cpu":
  3468. offset2bag = indices.new_empty(indices.size(0))
  3469. bag_size = indices.new_empty(offsets.size())
  3470. if mode == MODE_MAX:
  3471. max_indices = indices.new_empty(num_bags, weight.size(1))
  3472. else:
  3473. max_indices = indices.new_empty(0)
  3474. else:
  3475. fast_path_sum = is_fast_path(weight, per_sample_weights, output, padding_idx)
  3476. if mode in (MODE_MEAN, MODE_MAX) or not fast_path_sum:
  3477. offset2bag = offsets.new_empty(indices.size(0))
  3478. else:
  3479. offset2bag = offsets.new_empty(0)
  3480. bag_size = offsets.new_empty(num_bags)
  3481. # This part of the logic comes from make_max_indices_out in EmbeddingBag.cpp
  3482. numBags = offsets.shape[0]
  3483. if mode == MODE_MAX:
  3484. if include_last_offset:
  3485. torch._check(
  3486. numBags >= 1,
  3487. lambda: "include_last_offset: numBags should be at least 1",
  3488. )
  3489. numBags -= 1
  3490. max_indices = offsets.new_empty(numBags, weight.shape[1])
  3491. else:
  3492. max_indices = offsets.new_empty(bag_size.size())
  3493. return output, offset2bag, bag_size, max_indices
  3494. @register_meta(aten._embedding_bag_forward_only.default)
  3495. def meta_embedding_bag_forward_only(weight, indices, offsets, *args):
  3496. output, offset2bag, bag_size, max_indices = meta_embedding_bag(
  3497. weight, indices, offsets, *args
  3498. )
  3499. if device_hint(offsets) == "cpu":
  3500. bag_size = offsets.new_empty(offsets.size())
  3501. return output, offset2bag, bag_size, max_indices
  3502. def _get_reduction_dtype(input, dtype, promote_int_to_long=True):
  3503. # if specified, dtype takes precedence
  3504. if dtype:
  3505. return dtype
  3506. if input.dtype.is_floating_point or input.dtype.is_complex:
  3507. return input.dtype
  3508. elif promote_int_to_long:
  3509. return torch.long
  3510. return input.dtype
  3511. @register_meta([aten.nansum.default, aten.nansum.out])
  3512. @out_wrapper()
  3513. def meta_nansum(input, dims=None, keepdim=False, *, dtype=None):
  3514. output_dtype = _get_reduction_dtype(input, dtype, promote_int_to_long=True)
  3515. dims = utils.reduction_dims(input.shape, dims)
  3516. output_shape = _compute_reduction_shape(input, dims, keepdim)
  3517. return input.new_empty(output_shape, dtype=output_dtype)
  3518. @register_meta([aten.median.default, aten.nanmedian.default])
  3519. def meta_median(input):
  3520. output_shape = utils.compute_reduction_output_shape(
  3521. input.shape, tuple(range(input.dim()))
  3522. )
  3523. return input.new_empty(output_shape)
  3524. @register_meta(
  3525. [
  3526. aten.median.dim,
  3527. aten.median.dim_values,
  3528. aten.nanmedian.dim,
  3529. aten.nanmedian.dim_values,
  3530. aten.mode.default,
  3531. aten.mode.values,
  3532. ]
  3533. )
  3534. @out_wrapper("values", "indices")
  3535. def meta_median_mode_dim(input, dim=-1, keepdim=False):
  3536. if device_hint(input) == "cuda":
  3537. utils.alert_not_deterministic("median CUDA with indices output")
  3538. dim = utils.reduction_dims(input.shape, (dim,))
  3539. output_shape = _compute_reduction_shape(input, dim, keepdim)
  3540. return (
  3541. input.new_empty(output_shape),
  3542. input.new_empty(output_shape, dtype=torch.long),
  3543. )
  3544. @register_meta(aten.logical_not_.default)
  3545. def meta_logical_not_(self):
  3546. return self
  3547. @register_meta(aten.repeat.default)
  3548. def meta_repeat(self, repeats):
  3549. torch._check(
  3550. len(repeats) >= self.dim(),
  3551. lambda: "Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor",
  3552. )
  3553. for i, rep in enumerate(repeats):
  3554. torch._check(
  3555. rep >= 0,
  3556. lambda: f"Repeats cannot be negative, found {rep} at index {i}",
  3557. )
  3558. # Add new leading dimensions to the tensor if the
  3559. # number of target dimensions is larger than the
  3560. # number of source dimensions.
  3561. num_new_dimensions = len(repeats) - self.dim()
  3562. padded_size = (1,) * num_new_dimensions + tuple(self.shape)
  3563. target_size = [padded_size[i] * repeats[i] for i in range(len(repeats))]
  3564. return self.new_empty(target_size)
  3565. @register_meta(aten.zero_.default)
  3566. def meta_zero_(self):
  3567. return self
  3568. @register_meta(
  3569. [
  3570. aten.mul_.Scalar,
  3571. aten.div_.Scalar,
  3572. aten.mul_.Tensor,
  3573. aten.div_.Tensor,
  3574. aten.logical_and_.default,
  3575. aten.logical_or_.default,
  3576. aten.logical_xor_.default,
  3577. ],
  3578. )
  3579. def meta_binop_inplace(self, other):
  3580. if isinstance(other, torch.Tensor):
  3581. check_inplace_broadcast(self.shape, other.shape)
  3582. return self
  3583. @register_meta(
  3584. [
  3585. aten.add_.Scalar,
  3586. aten.sub_.Scalar,
  3587. aten.add_.Tensor,
  3588. aten.sub_.Tensor,
  3589. ],
  3590. )
  3591. def meta_binop_inplace_alpha(self, other, alpha=1):
  3592. """
  3593. Some checks for inplace ops.
  3594. Checks for promotion rules for some dtypes.
  3595. int.add/sub_(float) and bool.add/sub_(others) are rejected.
  3596. Promoting in these in-place operations would require reallocating
  3597. and copying over elements, hence not allowed.
  3598. Checks for alpha param.
  3599. """
  3600. def is_integeric(arg):
  3601. if isinstance(arg, TensorLike):
  3602. return utils.is_integer_dtype(arg.dtype)
  3603. else:
  3604. return isinstance(arg, IntLike)
  3605. def is_floatic(arg):
  3606. if isinstance(arg, TensorLike):
  3607. return utils.is_float_dtype(arg.dtype)
  3608. else:
  3609. return isinstance(arg, FloatLike)
  3610. def is_booleanic(arg):
  3611. if isinstance(arg, TensorLike):
  3612. return utils.is_boolean_dtype(arg.dtype)
  3613. else:
  3614. return isinstance(arg, BoolLike)
  3615. # Do not allow int+float->int in-place
  3616. if is_integeric(self) and is_floatic(other):
  3617. raise RuntimeError(
  3618. "Promotion of int.add/sub_(float) in in-place ops are not possible due to element size change."
  3619. )
  3620. # Do not allow bool+other->bool in-place
  3621. if is_booleanic(self) and not is_booleanic(other):
  3622. raise RuntimeError(
  3623. "Promotion of book.add/sub_(others) in in-place ops are not possible due to element size change."
  3624. )
  3625. if isinstance(other, torch.Tensor):
  3626. check_inplace_broadcast(self.shape, other.shape)
  3627. return self
  3628. @register_meta(
  3629. [
  3630. aten.add.Scalar,
  3631. aten.sub.Scalar,
  3632. ],
  3633. )
  3634. def meta_binop_alpha(self, other, alpha=1):
  3635. return elementwise_meta(
  3636. self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3637. )
  3638. @register_meta([aten.round.default, aten.round.decimals])
  3639. def meta_round(self, **kwargs):
  3640. return elementwise_meta(
  3641. self, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3642. )
  3643. def shift_dtype_check(fn_name, self, val):
  3644. torch._check(
  3645. utils.is_integer_dtype(self.dtype),
  3646. lambda: f"{fn_name}: Expected input tensor to have an integral dtype. Got {self.dtype}",
  3647. )
  3648. if isinstance(val, torch.Tensor):
  3649. torch._check(
  3650. utils.is_integer_dtype(val.dtype),
  3651. lambda: f"{fn_name}: Expected shift value to have an integral dtype. Got {val.dtype}",
  3652. )
  3653. else:
  3654. torch._check(
  3655. isinstance(val, IntLike),
  3656. lambda: f"{fn_name}: Expected shift value to be an int. Got {val}",
  3657. )
  3658. @register_meta([aten.__rshift__.Tensor, aten.__rshift__.Scalar])
  3659. def meta_rshifts(self, other):
  3660. shift_dtype_check("rshift", self, other)
  3661. return elementwise_meta(
  3662. self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3663. )
  3664. @register_meta([aten.__lshift__.Tensor, aten.__lshift__.Scalar])
  3665. def meta_lshifts(self, other):
  3666. shift_dtype_check("lshift", self, other)
  3667. return elementwise_meta(
  3668. self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3669. )
  3670. @register_meta(aten.zero.default)
  3671. def meta_zero(self):
  3672. return self.new_empty(self.shape)
  3673. @register_meta([aten.fill_.Tensor, aten.fill_.Scalar])
  3674. def meta_fill_(self, val):
  3675. return self
  3676. @register_meta([aten.fill.Tensor, aten.fill.Scalar])
  3677. def meta_fill(self, val):
  3678. return torch.empty_like(self)
  3679. @register_meta(aten.relu_.default)
  3680. def meta_relu_(self):
  3681. return self
  3682. @register_meta(aten._add_relu.Tensor)
  3683. @out_wrapper()
  3684. def meta__add_relu(self, other, alpha=1) -> Tensor:
  3685. return elementwise_meta(
  3686. self, other, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  3687. )
  3688. @register_meta([aten.rrelu_with_noise])
  3689. @out_wrapper()
  3690. def meta_rrelu_with_noise(
  3691. self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
  3692. ):
  3693. return torch.empty_like(self)
  3694. @register_meta([aten.rrelu_with_noise_functional])
  3695. def meta_rrelu_with_noise_functional(
  3696. self, noise, lower=0.125, upper=0.3333333333333333, training=False, generator=None
  3697. ):
  3698. return torch.empty_like(self), torch.empty_like(noise)
  3699. @register_meta([aten.rrelu_with_noise_])
  3700. def meta_rrelu_with_noise_(
  3701. self, lower=0.125, upper=0.3333333333333333, training=False, generator=None
  3702. ):
  3703. return self
  3704. @register_meta([aten.index_put.default, aten._unsafe_index_put.default])
  3705. def meta_index_put(self, indices, values, accumulate=False):
  3706. return torch.empty_like(self)
  3707. @register_meta(aten.masked_fill_.Scalar)
  3708. def meta_masked_fill_(self, mask, value):
  3709. check_inplace_broadcast(self.shape, mask.shape)
  3710. return self
  3711. @register_meta(aten._masked_scale.default)
  3712. def meta__masked_scale(self, mask, scale):
  3713. masked_scale = self.new_empty(self.size()).to(
  3714. memory_format=utils.suggest_memory_format(self)
  3715. )
  3716. return masked_scale
  3717. @register_meta(aten.masked_scatter_)
  3718. def meta_masked_scatter_(self, mask, source):
  3719. torch._check(
  3720. mask.dtype in (torch.bool, torch.uint8), lambda: "Mask must be bool or uint8"
  3721. )
  3722. torch._check(
  3723. self.dtype == source.dtype,
  3724. lambda: "masked_scatter: expected self and source to have same "
  3725. f"dtypes but got {self.dtype} and {source.dtype}",
  3726. )
  3727. return self
  3728. @register_meta(aten.masked_scatter)
  3729. @out_wrapper()
  3730. def meta_masked_scatter(self, mask, source):
  3731. self, mask = _maybe_broadcast(self, mask)
  3732. output = torch.empty_like(self, memory_format=torch.contiguous_format)
  3733. return meta_masked_scatter_(output, mask, source)
  3734. @register_meta(aten.masked_scatter_backward)
  3735. def meta_masked_scatter_backward(self, mask, sizes):
  3736. return self.new_empty(sizes)
  3737. @register_meta(aten.index_put_.default)
  3738. def meta_index_put_(self, indices, values, accumulate=False):
  3739. return self
  3740. def common_meta_baddbmm_bmm(batch1, batch2, is_bmm, self_baddbmm=None, out_dtype=None):
  3741. from torch.fx.experimental.symbolic_shapes import sym_and, sym_eq
  3742. torch._check(batch1.dim() == 3, lambda: "batch1 must be a 3D tensor")
  3743. torch._check(batch2.dim() == 3, lambda: "batch2 must be a 3D tensor")
  3744. batch1_sizes = batch1.size()
  3745. batch2_sizes = batch2.size()
  3746. bs = batch1_sizes[0]
  3747. contraction_size = batch1_sizes[2]
  3748. res_rows = batch1_sizes[1]
  3749. res_cols = batch2_sizes[2]
  3750. output_size = (bs, res_rows, res_cols)
  3751. torch._check(
  3752. sym_and(sym_eq(batch2_sizes[0], bs), sym_eq(batch2_sizes[1], contraction_size)),
  3753. lambda: f"Expected size for first two dimensions of batch2 tensor to be: [{bs}"
  3754. f", {contraction_size}] but got: [{batch2_sizes[0]}, {batch2_sizes[1]}].",
  3755. )
  3756. if out_dtype:
  3757. supported_out_dtype = (
  3758. batch1.dtype == torch.float16 or batch1.dtype == torch.bfloat16
  3759. ) and out_dtype == torch.float32
  3760. torch._check(
  3761. out_dtype == batch1.dtype or supported_out_dtype,
  3762. lambda: "out_dtype only supported for torch.float32 output with float16/bfloat16 inputs or same as input dtypes",
  3763. )
  3764. output = batch2.new_empty(output_size).to(out_dtype)
  3765. else:
  3766. # TODO: handle out
  3767. output = batch2.new_empty(output_size)
  3768. if not is_bmm and self_baddbmm is not None:
  3769. torch._check(self_baddbmm.dim() == 3, lambda: "self must be a 3D tensor")
  3770. torch._check(
  3771. sym_eq(self_baddbmm.size(), output_size),
  3772. lambda: f"Expected an input tensor shape with shape {output_size} but got shape: {self_baddbmm.size()}",
  3773. )
  3774. return output
  3775. @register_meta(aten.bmm.default)
  3776. def meta_bmm(self, mat2):
  3777. return common_meta_baddbmm_bmm(self, mat2, True)
  3778. @register_meta(aten.bmm.dtype)
  3779. def meta_bmm_dtype(self, mat2, out_dtype):
  3780. return common_meta_baddbmm_bmm(self, mat2, True, out_dtype=out_dtype)
  3781. def div_rtn(x, y):
  3782. q = x // y
  3783. r = x % y
  3784. # WARNING: explicit bool conversion here is necessary;
  3785. # would be fixed by SymBool
  3786. if r != 0 and (bool(r < 0) != bool(y < 0)):
  3787. q -= 1
  3788. return q
  3789. def pooling_output_shape_pad_lr(
  3790. inputSize,
  3791. kernelSize,
  3792. pad_l,
  3793. pad_r,
  3794. stride,
  3795. dilation,
  3796. ceil_mode,
  3797. ):
  3798. outputSize = (
  3799. div_rtn(
  3800. inputSize
  3801. + pad_l
  3802. + pad_r
  3803. - dilation * (kernelSize - 1)
  3804. - 1
  3805. + (stride - 1 if ceil_mode else 0),
  3806. stride,
  3807. )
  3808. + 1
  3809. )
  3810. if ceil_mode:
  3811. if (outputSize - 1) * stride >= inputSize + pad_l:
  3812. outputSize -= 1
  3813. return outputSize
  3814. def pooling_output_shape(inputSize, kernelSize, pad, stride, dilation, ceil_mode):
  3815. torch._check(stride != 0, lambda: "stride should not be zero")
  3816. torch._check(pad >= 0, lambda: f"pad must be non-negative, but got pad: {pad}")
  3817. torch._check(
  3818. pad <= ((kernelSize - 1) * dilation + 1) // 2,
  3819. lambda: (
  3820. f"pad should be at most half of effective kernel size, but got pad={pad}, "
  3821. f"kernel_size={kernelSize} and dilation={dilation}"
  3822. ),
  3823. )
  3824. return pooling_output_shape_pad_lr(
  3825. inputSize, kernelSize, pad, pad, stride, dilation, ceil_mode
  3826. )
  3827. def pool2d_shape_check(
  3828. input,
  3829. kH,
  3830. kW,
  3831. dH,
  3832. dW,
  3833. padH,
  3834. padW,
  3835. dilationH,
  3836. dilationW,
  3837. nInputPlane,
  3838. inputHeight,
  3839. inputWidth,
  3840. outputHeight,
  3841. outputWidth,
  3842. memory_format,
  3843. ):
  3844. ndim = input.dim()
  3845. nOutputPlane = nInputPlane
  3846. torch._check(
  3847. kW > 0 and kH > 0,
  3848. lambda: f"kernel size should be greater than zero, but got kH: {kH}, kW: {kW}",
  3849. )
  3850. torch._check(
  3851. dW > 0 and dH > 0,
  3852. lambda: f"stride should be greater than zero, but got dH: {dH}, dW: {dW}",
  3853. )
  3854. torch._check(
  3855. dilationH > 0 and dilationW > 0,
  3856. lambda: f"dilation should be greater than zero, but got dilationH: {dilationH}, dilationW: {dilationW}",
  3857. )
  3858. valid_dims = input.size(1) != 0 and input.size(2) != 0
  3859. if memory_format == torch.channels_last:
  3860. torch._check(
  3861. ndim == 4 and valid_dims and input.size(3) != 0,
  3862. lambda: "Expected 4D (batch mode) tensor expected for input with channels_last layout"
  3863. f" with optional 0 dim batch size for input, but got: {input.size()}",
  3864. )
  3865. else:
  3866. torch._check(
  3867. (ndim == 3 and input.size(0) != 0 and valid_dims)
  3868. or (ndim == 4 and valid_dims and input.size(3) != 0),
  3869. lambda: f"Expected 3D or 4D (batch mode) tensor with optional 0 dim batch size for input, but got: {input.size()}",
  3870. )
  3871. torch._check(
  3872. kW // 2 >= padW and kH // 2 >= padH,
  3873. lambda: "pad should be smaller than or equal to half of kernel size, but got "
  3874. f"padW = {padW}, padH = {padH}, kW = {kW}, kH = {kH}",
  3875. )
  3876. torch._check(
  3877. outputWidth >= 1 and outputHeight >= 1,
  3878. lambda: f"Given input size: ({nInputPlane}x{inputHeight}x{inputWidth}). "
  3879. f"Calculated output size: ({nOutputPlane}x{outputHeight}x{outputWidth}). "
  3880. "Output size is too small",
  3881. )
  3882. def pool3d_shape_check(
  3883. input: Tensor,
  3884. nslices: int,
  3885. kT: int,
  3886. kH: int,
  3887. kW: int,
  3888. dT: int,
  3889. dH: int,
  3890. dW: int,
  3891. pT: int,
  3892. pH: int,
  3893. pW: int,
  3894. dilationT: int,
  3895. dilationH: int,
  3896. dilationW: int,
  3897. itime: int,
  3898. iheight: int,
  3899. iwidth: int,
  3900. otime: int,
  3901. oheight: int,
  3902. owidth: int,
  3903. fn_name: str,
  3904. check_input_size: bool = False,
  3905. ):
  3906. ndim = input.ndim
  3907. torch._check(
  3908. kT > 0 and kW > 0 and kH > 0,
  3909. lambda: (
  3910. f"kernel size should be greater than zero, but got "
  3911. f"kT: {kT}, kH: {kH}, kW: {kW}"
  3912. ),
  3913. )
  3914. torch._check(
  3915. dT > 0 and dW > 0 and dH > 0,
  3916. lambda: (
  3917. f"stride should be greater than zero, but got dT: {dT}, dH: {dH}, dW: {dW}"
  3918. ),
  3919. )
  3920. torch._check(
  3921. dilationT > 0 and dilationW > 0 and dilationH > 0,
  3922. lambda: (
  3923. f"dilation should be greater than zero, but got "
  3924. f"dilationT: {dilationT}, dilationH: {dilationH}, dilationW: {dilationW}"
  3925. ),
  3926. )
  3927. torch._check(
  3928. ndim in (4, 5),
  3929. lambda: f"{fn_name}: Expected 4D or 5D tensor for input, but got: {input.shape}",
  3930. )
  3931. for i in range(ndim):
  3932. if ndim == 5 and i == 0:
  3933. # size of batch-dim can be 0.
  3934. continue
  3935. torch._check(
  3936. input.size(i) > 0,
  3937. lambda: (
  3938. f"{fn_name}: Expected input's non-batch dimensions to have positive length,"
  3939. f" but input has a shape of {input.shape}"
  3940. f" and non-batch dimension {input.size(i)} has length zero!"
  3941. ),
  3942. )
  3943. if check_input_size: # AveragePool3d
  3944. torch._check(
  3945. itime >= kT and iheight >= kH and iwidth >= kW,
  3946. lambda: (
  3947. f"input image (T: {itime} H: {iheight} W: {iwidth}) smaller than "
  3948. f"kernel size (kT: {kT} kH: {kH} kW: {kW})"
  3949. ),
  3950. )
  3951. torch._check(
  3952. kT / 2 >= pT and kW / 2 >= pW and kH / 2 >= pH,
  3953. lambda: (
  3954. f"pad should be smaller than or equal to half of kernel size, but got "
  3955. f"kT: {kT} kW: {kW} kH: {kH} padT: {pT} padW: {pW} padH: {pH}"
  3956. ),
  3957. )
  3958. torch._check(
  3959. otime >= 1 and owidth >= 1 and oheight >= 1,
  3960. lambda: (
  3961. f"Given input size: ({nslices}x{itime}x{iheight}x{iwidth}). "
  3962. f"Calculated output size: ({nslices}x{otime}x{oheight}x{owidth}). "
  3963. f"Output size is too small"
  3964. ),
  3965. )
  3966. def max_pool3d_backward_shape_check(
  3967. input,
  3968. grad_output,
  3969. indices,
  3970. nslices,
  3971. kT,
  3972. kH,
  3973. kW,
  3974. dT,
  3975. dH,
  3976. dW,
  3977. pT,
  3978. pH,
  3979. pW,
  3980. dilationT,
  3981. dilationH,
  3982. dilationW,
  3983. itime,
  3984. iheight,
  3985. iwidth,
  3986. otime,
  3987. oheight,
  3988. owidth,
  3989. fn_name,
  3990. ):
  3991. ndim = input.ndim
  3992. pool3d_shape_check(
  3993. input,
  3994. nslices,
  3995. kT,
  3996. kH,
  3997. kW,
  3998. dT,
  3999. dH,
  4000. dW,
  4001. pT,
  4002. pH,
  4003. pW,
  4004. dilationT,
  4005. dilationH,
  4006. dilationW,
  4007. itime,
  4008. iheight,
  4009. iwidth,
  4010. otime,
  4011. oheight,
  4012. owidth,
  4013. fn_name,
  4014. )
  4015. check_dim_size(grad_output, ndim, ndim - 4, nslices)
  4016. check_dim_size(grad_output, ndim, ndim - 3, otime)
  4017. check_dim_size(grad_output, ndim, ndim - 2, oheight)
  4018. check_dim_size(grad_output, ndim, ndim - 1, owidth)
  4019. check_dim_size(indices, ndim, ndim - 4, nslices)
  4020. check_dim_size(indices, ndim, ndim - 3, otime)
  4021. check_dim_size(indices, ndim, ndim - 2, oheight)
  4022. check_dim_size(indices, ndim, ndim - 1, owidth)
  4023. def avg_pool3d_backward_shape_check(
  4024. input: Tensor,
  4025. grad_output: Tensor,
  4026. nslices: int,
  4027. kT: int,
  4028. kH: int,
  4029. kW: int,
  4030. dT: int,
  4031. dH: int,
  4032. dW: int,
  4033. pT: int,
  4034. pH: int,
  4035. pW: int,
  4036. itime: int,
  4037. iheight: int,
  4038. iwidth: int,
  4039. otime: int,
  4040. oheight: int,
  4041. owidth: int,
  4042. fn_name: str,
  4043. ):
  4044. ndim = input.ndim
  4045. pool3d_shape_check(
  4046. input,
  4047. nslices,
  4048. kT,
  4049. kH,
  4050. kW,
  4051. dT,
  4052. dH,
  4053. dW,
  4054. pT,
  4055. pH,
  4056. pW,
  4057. 1,
  4058. 1,
  4059. 1,
  4060. itime,
  4061. iheight,
  4062. iwidth,
  4063. otime,
  4064. oheight,
  4065. owidth,
  4066. fn_name,
  4067. True,
  4068. )
  4069. check_dim_size(grad_output, ndim, ndim - 4, nslices)
  4070. check_dim_size(grad_output, ndim, ndim - 3, otime)
  4071. check_dim_size(grad_output, ndim, ndim - 2, oheight)
  4072. check_dim_size(grad_output, ndim, ndim - 1, owidth)
  4073. def max_pool2d_checks_and_compute_shape(
  4074. input,
  4075. kernel_size,
  4076. stride,
  4077. padding,
  4078. dilation,
  4079. ceil_mode,
  4080. ):
  4081. # Reference: aten/src/ATen/native/DilatedMaxPool2d.cpp
  4082. def unpack(name, val):
  4083. torch._check(
  4084. len(val) in [1, 2],
  4085. lambda: f"max_pool2d: {name} must either be a single int, or a tuple of two ints",
  4086. )
  4087. H = val[0]
  4088. W = H if len(val) == 1 else val[1]
  4089. return H, W
  4090. kH, kW = unpack("kernel_size", kernel_size)
  4091. torch._check(
  4092. len(stride) in [0, 1, 2],
  4093. lambda: "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints",
  4094. )
  4095. if len(stride) == 0:
  4096. dH, dW = kH, kW
  4097. else:
  4098. dH, dW = unpack("stride", stride)
  4099. padH, padW = unpack("padding", padding)
  4100. dilationH, dilationW = unpack("dilation", dilation)
  4101. nInputPlane = input.size(-3)
  4102. inputHeight = input.size(-2)
  4103. inputWidth = input.size(-1)
  4104. memory_format = utils.suggest_memory_format(input)
  4105. if memory_format == torch.channels_last:
  4106. torch._check(
  4107. input.dim() == 4,
  4108. lambda: "non-empty 4D (batch mode) tensor expected for input with channels_last layout",
  4109. )
  4110. elif memory_format == torch.contiguous_format:
  4111. torch._check(
  4112. input.dim() in [3, 4],
  4113. lambda: "non-empty 3D or 4D (batch mode) tensor expected for input",
  4114. )
  4115. else:
  4116. torch._check(
  4117. False,
  4118. lambda: "Unsupported memory format. Supports only ChannelsLast, Contiguous",
  4119. )
  4120. outputHeight = pooling_output_shape(inputHeight, kH, padH, dH, dilationH, ceil_mode)
  4121. outputWidth = pooling_output_shape(inputWidth, kW, padW, dW, dilationW, ceil_mode)
  4122. pool2d_shape_check(
  4123. input,
  4124. kH,
  4125. kW,
  4126. dH,
  4127. dW,
  4128. padH,
  4129. padW,
  4130. dilationH,
  4131. dilationW,
  4132. nInputPlane,
  4133. inputHeight,
  4134. inputWidth,
  4135. outputHeight,
  4136. outputWidth,
  4137. memory_format,
  4138. )
  4139. return nInputPlane, outputHeight, outputWidth
  4140. @register_meta(aten.max_pool2d_with_indices_backward.default)
  4141. def meta_max_pool2d_with_indices_backward(
  4142. grad_output,
  4143. self,
  4144. kernel_size,
  4145. stride,
  4146. padding,
  4147. dilation,
  4148. ceil_mode,
  4149. indices,
  4150. ):
  4151. (
  4152. nInputPlane,
  4153. outputHeight,
  4154. outputWidth,
  4155. ) = max_pool2d_checks_and_compute_shape(
  4156. self, kernel_size, stride, padding, dilation, ceil_mode
  4157. )
  4158. torch._check(
  4159. self.dtype == grad_output.dtype,
  4160. lambda: f"Expected dtype {self.dtype} for `gradOutput` but got dtype {grad_output.dtype}",
  4161. )
  4162. nOutputPlane = nInputPlane
  4163. ndim = self.ndim
  4164. def _check_dim_size(t):
  4165. check_dim_size(t, ndim, ndim - 3, nOutputPlane)
  4166. check_dim_size(t, ndim, ndim - 2, outputHeight)
  4167. check_dim_size(t, ndim, ndim - 1, outputWidth)
  4168. _check_dim_size(grad_output)
  4169. _check_dim_size(indices)
  4170. memory_format = utils.suggest_memory_format(self)
  4171. return torch.empty(
  4172. self.shape,
  4173. dtype=self.dtype,
  4174. device=self.device,
  4175. memory_format=memory_format,
  4176. )
  4177. @register_meta(aten.max_pool2d_with_indices.default)
  4178. def meta_max_pool2d_with_indices(
  4179. input,
  4180. kernel_size,
  4181. stride=(),
  4182. padding=(0,),
  4183. dilation=(1,),
  4184. ceil_mode=False,
  4185. ):
  4186. (
  4187. nInputPlane,
  4188. outputHeight,
  4189. outputWidth,
  4190. ) = max_pool2d_checks_and_compute_shape(
  4191. input, kernel_size, stride, padding, dilation, ceil_mode
  4192. )
  4193. nbatch = input.size(-4) if input.dim() == 4 else 1
  4194. memory_format = utils.suggest_memory_format(input)
  4195. if input.dim() == 3:
  4196. size = [nInputPlane, outputHeight, outputWidth]
  4197. else:
  4198. size = [nbatch, nInputPlane, outputHeight, outputWidth]
  4199. return (
  4200. torch.empty(
  4201. size,
  4202. dtype=input.dtype,
  4203. device=input.device,
  4204. memory_format=memory_format,
  4205. ),
  4206. torch.empty(
  4207. size,
  4208. dtype=torch.int64,
  4209. device=input.device,
  4210. memory_format=memory_format,
  4211. ),
  4212. )
  4213. @register_meta(aten.fractional_max_pool2d.default)
  4214. def meta_fractional_max_pool2d(self, kernel_size, output_size, random_samples):
  4215. torch._check(
  4216. self.ndim in (3, 4),
  4217. lambda: f"fractional_max_pool2d: Expected 3D or 4D tensor, but got: {self.ndim}",
  4218. )
  4219. ndim = self.ndim
  4220. for d in range(ndim - 3, ndim):
  4221. torch._check(
  4222. self.size(d) > 0,
  4223. lambda: f"fractional_max_pool2d: Expected input to have non-zero "
  4224. f" size for non-batch dimensions, but got {self.size()} with dimension {d} empty",
  4225. )
  4226. # the check and message are out of sync, but this matches the structured meta
  4227. torch._check(
  4228. len(kernel_size) == 2,
  4229. lambda: "fractional_max_pool2d: kernel_size must"
  4230. "either be a single int or tuple of Ints",
  4231. )
  4232. torch._check(
  4233. len(output_size) == 2,
  4234. lambda: "fractional_max_pool2d: output_size must "
  4235. "either be a single int or tuple of Ints",
  4236. )
  4237. input_channels = self.size(-3)
  4238. input_height = self.size(-2)
  4239. input_width = self.size(-1)
  4240. if ndim == 4:
  4241. input_batch = self.size(0)
  4242. else:
  4243. input_batch = 1
  4244. torch._check(
  4245. self.dtype == random_samples.dtype,
  4246. lambda: "Expect _random_samples to have the same dtype as input",
  4247. )
  4248. torch._check(
  4249. random_samples.ndim == 3,
  4250. lambda: f"Expect _random samples to have 3 dimensions got, {random_samples.ndim}",
  4251. )
  4252. n = random_samples.size(0)
  4253. c = random_samples.size(1)
  4254. d = random_samples.size(2)
  4255. torch._check(
  4256. n >= input_batch,
  4257. lambda: "Expect _random_samples.size(0) no less then input batch size.",
  4258. )
  4259. torch._check(
  4260. c == input_channels,
  4261. lambda: "Expect _random_samples.size(1) equals to input channel size.",
  4262. )
  4263. torch._check(d == 2, lambda: f"Expect _random_samples.size(2) equals to 2 got {d}.")
  4264. torch._check(
  4265. output_size[0] + kernel_size[0] - 1 <= input_height,
  4266. lambda: f"fractional_max_pool2d: kernel height {kernel_size[0]} is too large relative to input height {input_height}",
  4267. )
  4268. torch._check(
  4269. output_size[1] + kernel_size[1] - 1 <= input_width,
  4270. lambda: f"fractional_max_pool2d: kernel width {kernel_size[1]} is too large relative to input width {input_width}",
  4271. )
  4272. if self.dim() == 4:
  4273. size = [input_batch, input_channels, output_size[0], output_size[1]]
  4274. else:
  4275. size = [input_channels, output_size[0], output_size[1]]
  4276. return (
  4277. torch.empty(
  4278. size,
  4279. dtype=self.dtype,
  4280. device=self.device,
  4281. ),
  4282. torch.empty(
  4283. size,
  4284. dtype=torch.int64,
  4285. device=self.device,
  4286. ),
  4287. )
  4288. @register_meta(aten.max_pool3d_with_indices)
  4289. @out_wrapper("out", "indices")
  4290. def meta_max_pool3d_with_indices(
  4291. input,
  4292. kernel_size,
  4293. stride=(),
  4294. padding=(0,),
  4295. dilation=(1,),
  4296. ceil_mode=False,
  4297. ):
  4298. torch._check(
  4299. len(kernel_size) in (1, 3),
  4300. lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
  4301. )
  4302. kT = kernel_size[0]
  4303. kH = kT if len(kernel_size) == 1 else kernel_size[1]
  4304. kW = kT if len(kernel_size) == 1 else kernel_size[2]
  4305. torch._check(
  4306. not stride or len(stride) in (1, 3),
  4307. lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
  4308. )
  4309. dT = kT if not stride else stride[0]
  4310. dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
  4311. dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
  4312. torch._check(
  4313. len(padding) in (1, 3),
  4314. lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
  4315. )
  4316. pT = padding[0]
  4317. pH = pT if len(padding) == 1 else padding[1]
  4318. pW = pT if len(padding) == 1 else padding[2]
  4319. torch._check(
  4320. len(dilation) in (1, 3),
  4321. lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
  4322. )
  4323. dilationT = dilation[0]
  4324. dilationH = dilationT if len(dilation) == 1 else dilation[1]
  4325. dilationW = dilationT if len(dilation) == 1 else dilation[2]
  4326. torch._check(
  4327. input.ndim in (4, 5),
  4328. lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
  4329. )
  4330. nbatch = input.size(-5) if input.ndim == 5 else 1
  4331. nslices = input.size(-4)
  4332. itime = input.size(-3)
  4333. iheight = input.size(-2)
  4334. iwidth = input.size(-1)
  4335. otime = pooling_output_shape(itime, kT, pT, dT, dilationT, ceil_mode)
  4336. oheight = pooling_output_shape(iheight, kH, pH, dH, dilationH, ceil_mode)
  4337. owidth = pooling_output_shape(iwidth, kW, pW, dW, dilationW, ceil_mode)
  4338. pool3d_shape_check(
  4339. input,
  4340. nslices,
  4341. kT,
  4342. kH,
  4343. kW,
  4344. dT,
  4345. dH,
  4346. dW,
  4347. pT,
  4348. pH,
  4349. pW,
  4350. dilationT,
  4351. dilationH,
  4352. dilationW,
  4353. itime,
  4354. iheight,
  4355. iwidth,
  4356. otime,
  4357. oheight,
  4358. owidth,
  4359. "max_pool3d_with_indices()",
  4360. )
  4361. channels_last = (
  4362. input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
  4363. )
  4364. if input.ndim == 4:
  4365. input_channels_last_check = input.unsqueeze(0)
  4366. channels_last = (
  4367. not input_channels_last_check.is_contiguous()
  4368. ) and input_channels_last_check.is_contiguous(
  4369. memory_format=torch.channels_last_3d
  4370. )
  4371. out_shape = (nslices, otime, oheight, owidth)
  4372. else:
  4373. out_shape = (nbatch, nslices, otime, oheight, owidth) # type: ignore[assignment]
  4374. out = input.new_empty(out_shape)
  4375. indices = input.new_empty(out_shape, dtype=torch.int64)
  4376. if channels_last:
  4377. out = out.to(memory_format=torch.channels_last_3d)
  4378. indices = indices.to(memory_format=torch.channels_last_3d)
  4379. return out, indices
  4380. @register_meta(aten.max_pool3d_with_indices_backward)
  4381. @out_wrapper("grad_input")
  4382. def meta_max_pool3d_with_indices_backward(
  4383. grad_output,
  4384. input,
  4385. kernel_size,
  4386. stride,
  4387. padding,
  4388. dilation,
  4389. ceil_mode,
  4390. indices,
  4391. ):
  4392. torch._check(
  4393. len(kernel_size) in (1, 3),
  4394. lambda: "max_pool3d: kernel_size must either be a single int, or a tuple of three ints",
  4395. )
  4396. kT = kernel_size[0]
  4397. kH = kT if len(kernel_size) == 1 else kernel_size[1]
  4398. kW = kT if len(kernel_size) == 1 else kernel_size[2]
  4399. torch._check(
  4400. not stride or len(stride) in (1, 3),
  4401. lambda: "max_pool3d: stride must either be omitted, a single int, or a tuple of three ints",
  4402. )
  4403. dT = kT if not stride else stride[0]
  4404. dH = kH if not stride else (dT if len(stride) == 1 else stride[1])
  4405. dW = kW if not stride else (dT if len(stride) == 1 else stride[2])
  4406. torch._check(
  4407. len(padding) in (1, 3),
  4408. lambda: "max_pool3d: padding must either be a single int, or a tuple of three ints",
  4409. )
  4410. pT = padding[0]
  4411. pH = pT if len(padding) == 1 else padding[1]
  4412. pW = pT if len(padding) == 1 else padding[2]
  4413. torch._check(
  4414. len(dilation) in (1, 3),
  4415. lambda: "max_pool3d: dilation must be either a single int, or a tuple of three ints",
  4416. )
  4417. dilationT = dilation[0]
  4418. dilationH = dilationT if len(dilation) == 1 else dilation[1]
  4419. dilationW = dilationT if len(dilation) == 1 else dilation[2]
  4420. torch._check(
  4421. input.ndim in (4, 5),
  4422. lambda: "non-empty 4D or 5D (batch mode) tensor expected for input",
  4423. )
  4424. nslices = input.size(-4)
  4425. itime = input.size(-3)
  4426. iheight = input.size(-2)
  4427. iwidth = input.size(-1)
  4428. otime = grad_output.size(-3)
  4429. oheight = grad_output.size(-2)
  4430. owidth = grad_output.size(-1)
  4431. max_pool3d_backward_shape_check(
  4432. input,
  4433. grad_output,
  4434. indices,
  4435. nslices,
  4436. kT,
  4437. kH,
  4438. kW,
  4439. dT,
  4440. dH,
  4441. dW,
  4442. pT,
  4443. pH,
  4444. pW,
  4445. dilationT,
  4446. dilationH,
  4447. dilationW,
  4448. itime,
  4449. iheight,
  4450. iwidth,
  4451. otime,
  4452. oheight,
  4453. owidth,
  4454. "max_pool3d_with_indices_backward()",
  4455. )
  4456. channels_last = (
  4457. input.ndim == 5 and utils.suggest_memory_format(input) == torch.channels_last_3d
  4458. )
  4459. if input.ndim == 4:
  4460. input_channels_last_check = input.unsqueeze(0)
  4461. channels_last = (
  4462. not input_channels_last_check.is_contiguous()
  4463. ) and input_channels_last_check.is_contiguous(
  4464. memory_format=torch.channels_last_3d
  4465. )
  4466. grad_input = input.new_empty(input.shape)
  4467. if channels_last:
  4468. grad_input = grad_input.to(memory_format=torch.channels_last_3d)
  4469. return grad_input
  4470. def check_grid_sampler_common(input: Tensor, grid: Tensor):
  4471. torch._check(
  4472. input.device == grid.device,
  4473. lambda: (
  4474. f"grid_sampler(): expected input and grid to be on same device, but input "
  4475. f"is on {input.device} and grid is on {grid.device}"
  4476. ),
  4477. )
  4478. torch._check(
  4479. input.layout == torch.strided and grid.layout == torch.strided,
  4480. lambda: (
  4481. f"grid_sampler(): expected input and grid to have torch.strided layout, but "
  4482. f"input has {input.layout} and grid has {grid.layout}"
  4483. ),
  4484. )
  4485. torch._check(
  4486. input.shape[0] == grid.shape[0],
  4487. lambda: (
  4488. f"grid_sampler(): expected grid and input to have same batch size, but got "
  4489. f"input with sizes {input.shape} and grid with sizes {grid.shape}"
  4490. ),
  4491. )
  4492. torch._check(
  4493. grid.shape[-1] == input.ndim - 2,
  4494. lambda: (
  4495. f"grid_sampler(): expected grid to have size {input.ndim - 2} in last "
  4496. f"dimension, but got grid with sizes {grid.shape}"
  4497. ),
  4498. )
  4499. for i in range(2, input.ndim):
  4500. torch._check(
  4501. input.shape[i] > 0,
  4502. lambda: (
  4503. f"grid_sampler(): expected input to have non-empty spatial dimensions, "
  4504. f"but input has sizes {input.shape} with dimension {i} being empty"
  4505. ),
  4506. )
class GridSamplerInterpolation(Enum):
    """Integer codes for the ``interpolation_mode`` argument of grid_sampler ops.

    ``check_grid_sampler_3d`` compares ``interpolation_mode`` against
    ``BICUBIC.value`` to reject bicubic sampling on 5D input.
    """

    # NOTE(review): presumably these values mirror the ATen-side interpolation
    # enum; confirm they stay in sync before renumbering.
    BILINEAR = 0
    NEAREST = 1
    BICUBIC = 2
  4511. def check_grid_sampler_3d(input: Tensor, grid: Tensor, interpolation_mode: int):
  4512. torch._check(
  4513. input.ndim == 5 and input.ndim == grid.ndim,
  4514. lambda: (
  4515. f"grid_sampler(): expected 5D input and grid with same number of "
  4516. f"dimensions, but got input with sizes {input.shape}"
  4517. f" and grid with sizes {grid.shape}"
  4518. ),
  4519. )
  4520. torch._check(
  4521. not (
  4522. input.ndim == 5
  4523. and interpolation_mode == GridSamplerInterpolation.BICUBIC.value
  4524. ),
  4525. lambda: "grid_sampler(): bicubic interpolation only supports 4D input",
  4526. )
  4527. @register_meta(aten.grid_sampler_2d_backward.default)
  4528. def grid_sampler_2d_backward_meta(
  4529. grad_output,
  4530. input,
  4531. grid,
  4532. interpolation_mode,
  4533. padding_mode,
  4534. align_corners,
  4535. output_mask,
  4536. ):
  4537. input_requires_grad = output_mask[0]
  4538. if input_requires_grad:
  4539. grad_input = torch.zeros_like(input, memory_format=torch.contiguous_format)
  4540. else:
  4541. grad_input = None
  4542. grad_grid = torch.empty_like(grid, memory_format=torch.contiguous_format)
  4543. return (grad_input, grad_grid)
  4544. @register_meta(aten.grid_sampler_3d)
  4545. @out_wrapper()
  4546. def grid_sampler_3d(
  4547. input,
  4548. grid,
  4549. interpolation_mode,
  4550. padding_mode,
  4551. align_corners,
  4552. ):
  4553. check_grid_sampler_common(input, grid)
  4554. check_grid_sampler_3d(input, grid, interpolation_mode)
  4555. N = input.shape[0]
  4556. C = input.shape[1]
  4557. out_D = grid.shape[1]
  4558. out_H = grid.shape[2]
  4559. out_W = grid.shape[3]
  4560. return input.new_empty((N, C, out_D, out_H, out_W))
  4561. @register_meta(aten.grid_sampler_3d_backward)
  4562. @out_wrapper("grad_input", "grad_grid")
  4563. def grid_sampler_3d_backward(
  4564. grad_output,
  4565. input,
  4566. grid,
  4567. interpolation_mode,
  4568. padding_mode,
  4569. align_corners,
  4570. output_mask,
  4571. ):
  4572. check_grid_sampler_common(input, grid)
  4573. check_grid_sampler_3d(input, grid, interpolation_mode)
  4574. input_requires_grad = output_mask[0]
  4575. if input_requires_grad:
  4576. grad_input = torch.zeros_like(
  4577. input, memory_format=torch.legacy_contiguous_format
  4578. )
  4579. else:
  4580. grad_input = None
  4581. grad_grid = torch.empty_like(grid, memory_format=torch.legacy_contiguous_format)
  4582. return grad_input, grad_grid
  4583. @register_meta([aten.full.default])
  4584. def full(size, fill_value, *args, **kwargs):
  4585. dtype = kwargs.get("dtype")
  4586. if not dtype:
  4587. dtype = utils.get_dtype(fill_value)
  4588. kwargs["dtype"] = dtype
  4589. # pyrefly: ignore [not-iterable]
  4590. return torch.empty(size, *args, **kwargs)
  4591. # zeros_like is special cased to work for sparse
  4592. @register_meta(aten.zeros_like.default)
  4593. def zeros_like(
  4594. self,
  4595. dtype=None,
  4596. layout=None,
  4597. device=None,
  4598. pin_memory=None,
  4599. memory_format=None,
  4600. ):
  4601. if layout == torch.sparse_coo:
  4602. torch._check(
  4603. memory_format is None,
  4604. lambda: "memory format option is only supported by strided tensors",
  4605. )
  4606. res = torch.empty(
  4607. 0,
  4608. dtype=self.dtype if dtype is None else dtype,
  4609. layout=layout,
  4610. device=self.device if device is None else device,
  4611. pin_memory=pin_memory,
  4612. )
  4613. if self.is_sparse:
  4614. res.sparse_resize_and_clear_(
  4615. self.size(), self.sparse_dim(), self.dense_dim()
  4616. )
  4617. else:
  4618. res.sparse_resize_and_clear_(self.size(), self.dim(), 0)
  4619. res._coalesced_(True)
  4620. return res
  4621. res = aten.empty_like.default(
  4622. self,
  4623. dtype=dtype,
  4624. layout=layout,
  4625. device=device,
  4626. pin_memory=pin_memory,
  4627. memory_format=memory_format,
  4628. )
  4629. # device can be not "meta"
  4630. res.fill_(0)
  4631. return res
  4632. @register_meta([aten.ones.default, aten.ones.out])
  4633. @out_wrapper()
  4634. def meta_ones(
  4635. size,
  4636. *,
  4637. dtype=None,
  4638. layout=None,
  4639. device=None,
  4640. pin_memory=None,
  4641. requires_grad=False,
  4642. ):
  4643. if dtype is None:
  4644. dtype = torch.get_default_dtype()
  4645. if device is None:
  4646. device = torch.get_default_device()
  4647. if layout is None:
  4648. layout = torch.strided
  4649. return torch.empty(
  4650. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  4651. )
  4652. @register_meta([aten.zeros.default, aten.zeros.out])
  4653. @out_wrapper()
  4654. def meta_zeros(
  4655. size,
  4656. *,
  4657. dtype=None,
  4658. layout=None,
  4659. device=None,
  4660. pin_memory=None,
  4661. requires_grad=False,
  4662. ):
  4663. if dtype is None:
  4664. dtype = torch.get_default_dtype()
  4665. if device is None:
  4666. device = torch.get_default_device()
  4667. if layout is None:
  4668. layout = torch.strided
  4669. return torch.empty(
  4670. size, dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  4671. )
  4672. @register_meta(aten.select_scatter.default)
  4673. def meta_select_scatter(self, src, dim, index):
  4674. return utils.clone_preserve_strides(self)
  4675. @register_meta(aten.slice_scatter.default)
  4676. def meta_slice_scatter(self, src, dim=0, start=None, end=None, step=1):
  4677. return utils.clone_preserve_strides(self)
  4678. # TODO: Deduplicate this with canonicalize_dim
  4679. def maybe_wrap_dim(dim: int, dim_post_expr: int, wrap_scalar: bool = True):
  4680. if dim_post_expr <= 0:
  4681. assert wrap_scalar
  4682. dim_post_expr = 1
  4683. min = -dim_post_expr
  4684. max = dim_post_expr - 1
  4685. assert not (dim < min or dim > max), f"dim {dim} out of bounds ({min}, {max})"
  4686. if dim < 0:
  4687. dim += dim_post_expr
  4688. return dim
  4689. def ensure_nonempty_size(t, dim):
  4690. return 1 if t.dim() == 0 else t.shape[dim]
  4691. # From aten/src/ATen/native/ScatterGatherChecks.h
  4692. def gather_shape_check(self, dim, index):
  4693. self_dims = max(self.dim(), 1)
  4694. index_dims = max(index.dim(), 1)
  4695. torch._check(
  4696. self_dims == index_dims,
  4697. lambda: "Index tensor must have the same number of dimensions as input tensor",
  4698. )
  4699. for i in range(self_dims):
  4700. if i != dim:
  4701. torch._check(
  4702. ensure_nonempty_size(index, i) <= ensure_nonempty_size(self, i),
  4703. lambda: f"Size does not match at dimension {i} expected index {index.shape}"
  4704. + f" to be no larger than self {self.shape} apart from dimension {dim}",
  4705. )
  4706. @register_meta(aten.gather.default)
  4707. def meta_gather(self, dim, index, sparse_grad=False):
  4708. from torch.fx.experimental.symbolic_shapes import guard_or_false
  4709. wrapped_dim = maybe_wrap_dim(dim, self.dim())
  4710. is_index_empty = guard_or_false(index.numel() == 0)
  4711. if not is_index_empty:
  4712. torch._check(
  4713. index.dtype == torch.long or index.dtype == torch.int,
  4714. lambda: f"gather(): Expected dtype int32/int64 for index, but got {index.dtype}",
  4715. )
  4716. gather_shape_check(self, wrapped_dim, index)
  4717. return self.new_empty(index.shape)
  4718. # From aten/src/ATen/native/TensorAdvancedIndexing.cpp
  4719. def get_operator_enum(reduce_, use_new_options=False):
  4720. if use_new_options:
  4721. if reduce_ == "sum":
  4722. return "REDUCE_ADD"
  4723. elif reduce_ == "prod":
  4724. return "REDUCE_MULTIPLY"
  4725. elif reduce_ == "mean":
  4726. return "REDUCE_MEAN"
  4727. elif reduce_ == "amax":
  4728. return "REDUCE_MAXIMUM"
  4729. elif reduce_ == "amin":
  4730. return "REDUCE_MINIMUM"
  4731. torch._check(
  4732. False,
  4733. lambda: "reduce argument must be either sum, prod, mean, amax or amin.",
  4734. )
  4735. return
  4736. else:
  4737. if reduce_ == "add":
  4738. return "REDUCE_ADD"
  4739. elif reduce_ == "multiply":
  4740. return "REDUCE_MULTIPLY"
  4741. torch._check(False, lambda: "reduce argument must be either add or multiply.")
  4742. return
  4743. # From aten/src/ATen/native/ScatterGatherChecks.h
  4744. def scatter_gather_dtype_check(method_name, self, index, src_opt=None):
  4745. from torch.fx.experimental.symbolic_shapes import guard_or_true
  4746. if guard_or_true(index.numel() != 0):
  4747. torch._check(
  4748. index.dtype == torch.long or index.dtype == torch.int,
  4749. lambda: f"{method_name}(): Expected dtype int32/int64 for index",
  4750. )
  4751. if src_opt is not None:
  4752. torch._check(
  4753. self.dtype == src_opt.dtype,
  4754. lambda: f"{method_name}(): Expected self.dtype to be equal to src.dtype",
  4755. )
  4756. def ensure_nonempty_dim(dim):
  4757. return max(dim, 1)
  4758. # From aten/src/ATen/native/ScatterGatherChecks.h
  4759. def scatter_shape_check(self, dim, index, src_opt=None):
  4760. from torch.fx.experimental.symbolic_shapes import guard_or_false
  4761. if guard_or_false(index.numel() == 0):
  4762. return
  4763. torch._check(
  4764. ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
  4765. lambda: "Index tensor must have the same number of dimensions as self tensor",
  4766. )
  4767. is_wrong_shape = False
  4768. self_dims = ensure_nonempty_dim(self.dim())
  4769. # Check: index.size(d) <= self.size(d) for all d != dim
  4770. for d in range(self_dims):
  4771. index_d_size = ensure_nonempty_size(index, d)
  4772. if d == dim:
  4773. continue
  4774. if index_d_size > ensure_nonempty_size(self, d):
  4775. is_wrong_shape = True
  4776. break
  4777. # Check: index.size(d) <= src.size(d) for all d if src is Tensor
  4778. if not is_wrong_shape and src_opt is not None:
  4779. for d in range(self_dims):
  4780. index_d_size = ensure_nonempty_size(index, d)
  4781. if index_d_size > ensure_nonempty_size(src_opt, d):
  4782. is_wrong_shape = True
  4783. break
  4784. if src_opt is not None:
  4785. torch._check(
  4786. ensure_nonempty_dim(self.dim()) == ensure_nonempty_dim(index.dim()),
  4787. lambda: "Index tensor must have the same number of dimensions as self tensor",
  4788. )
  4789. torch._check(
  4790. not is_wrong_shape,
  4791. lambda: f"Expected index {index.shape} to be no larger than self {self.shape}"
  4792. + f" apart from dimension {dim} and to be no larger than src {src_opt.shape}",
  4793. )
  4794. else:
  4795. torch._check(
  4796. not is_wrong_shape,
  4797. lambda: f"Expected index {index.shape} to be no larger than self {self.shape}"
  4798. + f" apart from dimension {dim}",
  4799. )
  4800. # From aten/src/ATen/native/TensorAdvancedIndexing.cpp
  4801. def scatter_meta_impl(self, dim, index, src=None, reduce_=None, use_new_options=False):
  4802. wrapped_dim = maybe_wrap_dim(dim, self.dim())
  4803. scatter_gather_dtype_check("scatter", self, index, src)
  4804. scatter_shape_check(self, wrapped_dim, index, src)
  4805. if reduce_ is not None:
  4806. # Check if we have a valid reduce operator.
  4807. get_operator_enum(reduce_, use_new_options)
  4808. @register_meta(aten.scatter_add.default)
  4809. def meta_scatter_add(self, dim, index, src):
  4810. scatter_meta_impl(self, dim, index, src, "add")
  4811. return self.new_empty(self.shape)
  4812. @register_meta(aten.scatter_add_)
  4813. def meta_scatter_add_(self, dim, index, src):
  4814. scatter_meta_impl(self, dim, index, src, "add")
  4815. return self
  4816. @register_meta(
  4817. [
  4818. aten.scatter.src,
  4819. aten.scatter.value,
  4820. aten.scatter.reduce,
  4821. aten.scatter.value_reduce,
  4822. ]
  4823. )
  4824. @out_wrapper()
  4825. def meta_scatter(self, dim, index, src_or_value, reduce=None):
  4826. src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
  4827. scatter_meta_impl(self, dim, index, src, reduce)
  4828. return self.new_empty(self.shape)
  4829. @register_meta(
  4830. [
  4831. aten.scatter_.src,
  4832. aten.scatter_.value,
  4833. aten.scatter_.reduce,
  4834. aten.scatter_.value_reduce,
  4835. ]
  4836. )
  4837. def meta_scatter_(self, dim, index, src_or_value, reduce=None):
  4838. src = src_or_value if isinstance(src_or_value, torch.Tensor) else None
  4839. scatter_meta_impl(self, dim, index, src, reduce)
  4840. return self
  4841. @register_meta([aten._scaled_dot_product_flash_attention])
  4842. def meta__scaled_dot_product_flash_attention(
  4843. query: Tensor,
  4844. key: Tensor,
  4845. value: Tensor,
  4846. dropout_p: float = 0.0,
  4847. is_causal: bool = False,
  4848. return_debug_mask: bool = False,
  4849. scale: float | None = None,
  4850. ):
  4851. batch_size = query.size(0)
  4852. num_heads = query.size(1)
  4853. max_seqlen_batch_q = query.size(2)
  4854. head_dim = query.size(3)
  4855. max_seqlen_batch_k = key.size(2)
  4856. attention = torch.empty_like(query)
  4857. logsumexp = torch.empty(
  4858. (batch_size, num_heads, max_seqlen_batch_q),
  4859. dtype=torch.float,
  4860. device=query.device,
  4861. )
  4862. if return_debug_mask:
  4863. blocksize_c = 128 if head_dim > 64 else 256
  4864. max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
  4865. if max_seqlen_batch_k <= 128:
  4866. max_seqlen_k = 128
  4867. elif max_seqlen_batch_k <= 256:
  4868. max_seqlen_k = 256
  4869. debug_mask = torch.empty(
  4870. (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
  4871. dtype=query.dtype,
  4872. device=query.device,
  4873. )
  4874. else:
  4875. debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
  4876. # Note [Seed and Offset]: device for seed and offset below depends on whether we are
  4877. # capturing or not, but at the time of tracing we don't know if we
  4878. # are going to use cudagraphs or not, so we return meta tensors here
  4879. # it's possible we'll need to have some special handling in inductor for sdpa
  4880. # See [Note] BC breaking change to flash seed/offset
  4881. if torch.version.hip and torch.cuda.is_available() or device_hint(query) == "xpu":
  4882. # Maintain old path on AMD
  4883. seed = torch.empty((), dtype=torch.long, device="meta")
  4884. offset = torch.empty((), dtype=torch.long, device="meta")
  4885. else:
  4886. seed = torch.empty((2), dtype=torch.uint64, device="meta")
  4887. offset = torch.empty((), dtype=torch.uint64, device="meta")
  4888. return (
  4889. attention,
  4890. logsumexp,
  4891. None,
  4892. None,
  4893. max_seqlen_batch_q,
  4894. max_seqlen_batch_k,
  4895. seed,
  4896. offset,
  4897. debug_mask,
  4898. )
  4899. def alloc_with_matching_layout(
  4900. query: Tensor,
  4901. res_shape: tuple[int, ...],
  4902. ):
  4903. if tuple(query.shape) == res_shape:
  4904. res = torch.empty_like(query)
  4905. else:
  4906. dim_order = sorted(
  4907. [0, 1, 2, 3], key=lambda idx: query.stride()[idx], reverse=True
  4908. )
  4909. permuted_shape = [res_shape[idx] for idx in dim_order]
  4910. final_permute = [dim_order.index(i) for i in range(len(dim_order))]
  4911. res = torch.empty(
  4912. permuted_shape, dtype=query.dtype, device=query.device
  4913. ).permute(final_permute)
  4914. return res
  4915. @register_meta([aten._scaled_dot_product_cudnn_attention])
  4916. def meta__scaled_dot_product_cudnn_attention(
  4917. query: Tensor,
  4918. key: Tensor,
  4919. value: Tensor,
  4920. attn_bias: Tensor | None,
  4921. compute_log_sumexp: bool,
  4922. dropout_p: float = 0.0,
  4923. is_causal: bool = False,
  4924. return_debug_mask: bool = False,
  4925. scale: float | None = None,
  4926. ):
  4927. B = query.size(0)
  4928. H = query.size(1)
  4929. S_Q = query.size(2)
  4930. S_KV = key.size(2)
  4931. D_V = value.size(-1)
  4932. res_shape = (B, H, S_Q, D_V)
  4933. res = alloc_with_matching_layout(query, res_shape)
  4934. logsum_exp = torch.empty(
  4935. (B, H, S_Q, 1),
  4936. dtype=torch.float,
  4937. device=query.device,
  4938. )
  4939. # See Note [Seed and Offset]
  4940. seed = torch.empty((), dtype=torch.long, device="meta")
  4941. offset = torch.empty((), dtype=torch.long, device="meta")
  4942. return (
  4943. res,
  4944. logsum_exp,
  4945. None,
  4946. None,
  4947. S_Q,
  4948. S_KV,
  4949. seed,
  4950. offset,
  4951. None,
  4952. )
  4953. @register_meta([aten._scaled_dot_product_fused_attention_overrideable])
  4954. def meta__scaled_dot_product_fused_attention_overrideable(
  4955. query: Tensor,
  4956. key: Tensor,
  4957. value: Tensor,
  4958. attn_bias: Tensor | None = None,
  4959. dropout_p: float = 0.0,
  4960. is_causal: bool = False,
  4961. return_debug_mask: bool = False,
  4962. scale: float | None = None,
  4963. ):
  4964. B = query.size(0)
  4965. H_Q = query.size(1)
  4966. S_Q = query.size(2)
  4967. S_KV = key.size(2)
  4968. D_V = value.size(-1)
  4969. res_shape = (B, H_Q, S_Q, D_V)
  4970. res = alloc_with_matching_layout(query, res_shape)
  4971. logsum_exp = torch.empty(
  4972. (B, H_Q, S_Q),
  4973. dtype=torch.float,
  4974. device=query.device,
  4975. )
  4976. # See Note [Seed and Offset]
  4977. seed = torch.empty((), dtype=torch.long, device="meta")
  4978. offset = torch.empty((), dtype=torch.long, device="meta")
  4979. return (
  4980. res,
  4981. logsum_exp,
  4982. None,
  4983. None,
  4984. S_Q,
  4985. S_KV,
  4986. seed,
  4987. offset,
  4988. None,
  4989. )
  4990. @register_meta(
  4991. [
  4992. aten._scaled_dot_product_flash_attention_backward,
  4993. ]
  4994. )
  4995. def meta__scaled_dot_product_flash_backward(
  4996. grad_out: Tensor,
  4997. query: Tensor,
  4998. key: Tensor,
  4999. value: Tensor,
  5000. out: Tensor,
  5001. logsumexp: Tensor,
  5002. cum_seq_q: Tensor,
  5003. cum_seq_k: Tensor,
  5004. max_q: int,
  5005. max_k: int,
  5006. dropout_p: float,
  5007. is_causal: bool,
  5008. philox_seed: Tensor,
  5009. philox_offset: Tensor,
  5010. scale: float | None = None,
  5011. ):
  5012. grad_q = torch.empty_like(query)
  5013. grad_k = torch.empty_like(key)
  5014. grad_v = torch.empty_like(value)
  5015. return grad_q, grad_k, grad_v
  5016. @register_meta(
  5017. [
  5018. aten._scaled_dot_product_flash_attention_for_cpu,
  5019. ]
  5020. )
  5021. def meta__scaled_dot_product_flash_attention_for_cpu(
  5022. query: Tensor,
  5023. key: Tensor,
  5024. value: Tensor,
  5025. dropout_p: float = 0.0,
  5026. is_causal: bool = False,
  5027. attn_mask: Tensor | None = None,
  5028. scale: float | None = None,
  5029. ):
  5030. batch_size = query.size(0)
  5031. num_heads = query.size(1)
  5032. max_seqlen_batch_q = query.size(2)
  5033. attention = torch.empty_like(query)
  5034. logsumexp = torch.empty(
  5035. (
  5036. batch_size,
  5037. max_seqlen_batch_q,
  5038. num_heads,
  5039. ),
  5040. dtype=torch.float,
  5041. device=query.device,
  5042. ).transpose(1, 2)
  5043. return (
  5044. attention,
  5045. logsumexp,
  5046. )
  5047. @register_meta(
  5048. [
  5049. aten._scaled_dot_product_flash_attention_for_cpu_backward,
  5050. ]
  5051. )
  5052. def meta__scaled_dot_product_flash_attention_for_cpu_backward(
  5053. grad_out: Tensor,
  5054. query: Tensor,
  5055. key: Tensor,
  5056. value: Tensor,
  5057. out: Tensor,
  5058. logsumexp: Tensor,
  5059. dropout_p: float,
  5060. is_causal: bool,
  5061. attn_mask: Tensor | None = None,
  5062. scale: float | None = None,
  5063. ):
  5064. # cpus's grad layout is different from cuda's,
  5065. # i.e. (batch_size, seq_len, num_heads, head_dim)
  5066. grad_q = torch.empty_permuted(
  5067. query.size(),
  5068. (0, 2, 1, 3),
  5069. dtype=query.dtype,
  5070. device=query.device,
  5071. )
  5072. grad_k = torch.empty_permuted(
  5073. key.size(),
  5074. (0, 2, 1, 3),
  5075. dtype=key.dtype,
  5076. device=key.device,
  5077. )
  5078. grad_v = torch.empty_permuted(
  5079. value.size(),
  5080. (0, 2, 1, 3),
  5081. dtype=value.dtype,
  5082. device=value.device,
  5083. )
  5084. return grad_q, grad_k, grad_v
  5085. @register_meta([aten._scaled_dot_product_attention_math_for_mps])
  5086. def meta__scaled_dot_product_attention_math_for_mps(
  5087. query: Tensor,
  5088. key: Tensor,
  5089. value: Tensor,
  5090. attn_mask: Tensor | None = None,
  5091. dropout_p: float = 0.0,
  5092. is_causal: bool = False,
  5093. dropout_mask: Tensor | None = None,
  5094. scale: float | None = None,
  5095. ) -> tuple[Tensor, Tensor]:
  5096. def ensure_4d(x):
  5097. if x.dim() == 3:
  5098. return x.unsqueeze(0), True
  5099. elif x.dim() > 4:
  5100. batch_size = 1
  5101. for i in range(x.dim() - 3):
  5102. batch_size *= x.shape[i]
  5103. return x.view(batch_size, x.size(-3), x.size(-2), x.size(-1)), True
  5104. else:
  5105. return x, False
  5106. q_, unsqueezed = ensure_4d(query)
  5107. k_, _ = ensure_4d(key)
  5108. v_, _ = ensure_4d(value)
  5109. batch_size, num_head, q_size, head_size = q_.shape
  5110. _, k_size, max_seq_length, _ = k_.shape
  5111. def sdpa_vector_fast_mps():
  5112. out = q_.new_empty(q_.shape)
  5113. if unsqueezed:
  5114. out = out.view_as(query)
  5115. attn = q_.new_empty((batch_size, num_head, q_size, max_seq_length))
  5116. if unsqueezed:
  5117. if query.dim() == 3:
  5118. attn = attn.squeeze(0)
  5119. else:
  5120. shape = list(query.shape[:-3]) + attn.shape[1:4]
  5121. attn = attn.view(shape)
  5122. return out, attn
  5123. def sdpa_vector_2pass_mps():
  5124. blocks = 32
  5125. out = q_.new_empty(q_.shape)
  5126. intermediate = q_.new_empty((batch_size, num_head, q_size, blocks, head_size))
  5127. return out, intermediate
  5128. if (max_seq_length >= 1024) or (k_size < q_size and max_seq_length >= 4096):
  5129. return sdpa_vector_2pass_mps()
  5130. else:
  5131. return sdpa_vector_fast_mps()
  5132. @register_meta([aten._scaled_dot_product_efficient_attention])
  5133. def meta__scaled_dot_product_efficient_attention(
  5134. query: Tensor,
  5135. key: Tensor,
  5136. value: Tensor,
  5137. attn_bias: Tensor | None,
  5138. compute_log_sumexp: bool,
  5139. dropout_p=0.0,
  5140. is_causal: bool = False,
  5141. scale: float | None = None,
  5142. ):
  5143. query = query.transpose(1, 2)
  5144. key = key.transpose(1, 2)
  5145. value = value.transpose(1, 2)
  5146. B = query.size(0)
  5147. M = query.size(1)
  5148. num_heads = query.size(-2)
  5149. Kv = value.size(-1)
  5150. res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
  5151. if torch.version.hip and torch.cuda.is_available():
  5152. """Please see: https://github.com/pytorch/pytorch/issues/146848
  5153. longsumexp last dim should be seq length
  5154. """
  5155. logsumexp_dim = M if compute_log_sumexp else 0
  5156. else:
  5157. logsumexp_dim = math.ceil(M / 32) * 32 if compute_log_sumexp else 0
  5158. logsum_exp = torch.empty(
  5159. (B, num_heads, logsumexp_dim),
  5160. dtype=torch.float,
  5161. device=query.device,
  5162. )
  5163. res = res.transpose(1, 2)
  5164. # See Note [Seed and Offset]:
  5165. seed = torch.empty((), dtype=torch.long, device="meta")
  5166. offset = torch.empty((), dtype=torch.long, device="meta")
  5167. return res, logsum_exp, seed, offset
  5168. @register_meta(
  5169. [
  5170. aten._scaled_dot_product_efficient_attention_backward,
  5171. ]
  5172. )
  5173. def meta__scaled_dot_product_efficient_backward(
  5174. grad_out: Tensor,
  5175. query: Tensor,
  5176. key: Tensor,
  5177. value: Tensor,
  5178. attn_bias: Tensor | None,
  5179. out: Tensor,
  5180. logsumexp: Tensor,
  5181. philox_seed: Tensor,
  5182. philox_offset: Tensor,
  5183. dropout_p: float,
  5184. grad_input_mask: list[bool],
  5185. is_causal: bool = False,
  5186. scale: float | None = None,
  5187. ):
  5188. batch_size = query.size(0)
  5189. num_heads = query.size(1)
  5190. max_q = query.size(2)
  5191. head_dim = query.size(3)
  5192. head_dim_v = value.size(3)
  5193. max_k = key.size(2)
  5194. grad_q = torch.empty_permuted(
  5195. (batch_size, num_heads, max_q, head_dim),
  5196. (0, 2, 1, 3),
  5197. dtype=query.dtype,
  5198. device=query.device,
  5199. )
  5200. grad_k = torch.empty_permuted(
  5201. (batch_size, num_heads, max_k, head_dim),
  5202. (0, 2, 1, 3),
  5203. dtype=key.dtype,
  5204. device=key.device,
  5205. )
  5206. grad_v = torch.empty_permuted(
  5207. (batch_size, num_heads, max_k, head_dim_v),
  5208. (0, 2, 1, 3),
  5209. dtype=value.dtype,
  5210. device=value.device,
  5211. )
  5212. grad_bias = None
  5213. if attn_bias is not None and grad_input_mask[3]:
  5214. lastDim = attn_bias.size(-1)
  5215. lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
  5216. new_sizes = list(attn_bias.size())
  5217. new_sizes[-1] = lastDimAligned
  5218. grad_bias = torch.empty(
  5219. new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
  5220. )
  5221. grad_bias = grad_bias[..., :lastDim]
  5222. return grad_q, grad_k, grad_v, grad_bias
  5223. @register_meta(
  5224. [
  5225. aten._scaled_dot_product_cudnn_attention_backward,
  5226. ]
  5227. )
  5228. def meta__scaled_dot_product_cudnn_backward(
  5229. grad_out: Tensor,
  5230. query: Tensor,
  5231. key: Tensor,
  5232. value: Tensor,
  5233. out: Tensor,
  5234. logsumexp: Tensor,
  5235. philox_seed: Tensor,
  5236. philox_offset: Tensor,
  5237. attn_bias: Tensor,
  5238. cum_seq_q: Tensor,
  5239. cum_seq_k: Tensor,
  5240. max_q: int,
  5241. max_k: int,
  5242. dropout_p: float,
  5243. is_causal: bool,
  5244. scale: float | None = None,
  5245. ):
  5246. grad_q = torch.empty_like(query)
  5247. grad_k = torch.empty_like(key)
  5248. grad_v = torch.empty_like(value)
  5249. return grad_q, grad_k, grad_v
  5250. @register_meta(
  5251. [
  5252. aten._flash_attention_forward,
  5253. ]
  5254. )
  5255. def meta__flash_attention_forward(
  5256. query: Tensor,
  5257. key: Tensor,
  5258. value: Tensor,
  5259. cum_seq_q: Tensor | None,
  5260. cum_seq_k: Tensor | None,
  5261. max_q: int,
  5262. max_k: int,
  5263. dropout_p: float,
  5264. is_causal: bool,
  5265. return_debug_mask: bool,
  5266. scale: float | None = None,
  5267. window_size_left: int | None = None,
  5268. window_size_right: int | None = None,
  5269. seqused_k: Tensor | None = None,
  5270. alibi_slopes: Tensor | None = None,
  5271. ):
  5272. # NB: there are two underlying paths:
  5273. # 1. normal dense path; expect 4D inputs of shape (batch_size, seqlen, num_heads, head_dim)
  5274. # 2. varseqlen path; expect 3D inputs of shape (total, num_heads, head_dim) where total
  5275. # includes all batch item sequences. cum_seq_q / cum_seq_k contain offsets into total
  5276. batch_size = query.size(0) if cum_seq_q is None else cum_seq_q.numel() - 1
  5277. max_seqlen_batch_q = query.size(1) if cum_seq_q is None else max_q
  5278. max_seqlen_batch_k = key.size(1) if cum_seq_k is None else max_k
  5279. num_heads = query.size(-2)
  5280. head_dim = query.size(-1)
  5281. # Cuda Path
  5282. attention = torch.empty_like(query)
  5283. if cum_seq_q is None:
  5284. logsumexp = torch.empty(
  5285. (batch_size, num_heads, max_seqlen_batch_q),
  5286. dtype=torch.float,
  5287. device=query.device,
  5288. )
  5289. else:
  5290. total_q = query.size(0)
  5291. logsumexp = torch.empty(
  5292. (num_heads, total_q), dtype=torch.float, device=query.device
  5293. )
  5294. if return_debug_mask:
  5295. blocksize_c = 128 if head_dim > 64 else 256
  5296. max_seqlen_k = math.ceil(max_seqlen_batch_q / blocksize_c)
  5297. if max_seqlen_batch_k <= 128:
  5298. max_seqlen_k = 128
  5299. elif max_seqlen_batch_k <= 256:
  5300. max_seqlen_k = 256
  5301. debug_mask = torch.empty(
  5302. (batch_size, num_heads, max_seqlen_batch_q, max_seqlen_k),
  5303. dtype=query.dtype,
  5304. device=query.device,
  5305. )
  5306. else:
  5307. debug_mask = torch.empty(0, dtype=query.dtype, device=query.device)
  5308. # See Note [Seed and Offset]
  5309. # See [Note] BC breaking change to flash seed/offset
  5310. seed, offset = None, None
  5311. if torch.version.hip and torch.cuda.is_available():
  5312. # Maintain old path on AMD
  5313. seed = torch.empty((), dtype=torch.long, device="meta")
  5314. offset = torch.empty((), dtype=torch.long, device="meta")
  5315. else:
  5316. seed = torch.empty((2), dtype=torch.uint64, device="meta")
  5317. offset = torch.empty((), dtype=torch.uint64, device="meta")
  5318. return (
  5319. attention,
  5320. logsumexp,
  5321. seed,
  5322. offset,
  5323. debug_mask,
  5324. )
  5325. @register_meta(
  5326. [
  5327. aten._flash_attention_backward,
  5328. ]
  5329. )
  5330. def meta__flash_attention_backward(
  5331. grad_out: Tensor,
  5332. query: Tensor,
  5333. key: Tensor,
  5334. value: Tensor,
  5335. out: Tensor,
  5336. logsumexp: Tensor,
  5337. cum_seq_q: Tensor,
  5338. cum_seq_k: Tensor,
  5339. max_q: int,
  5340. max_k: int,
  5341. dropout_p: float,
  5342. is_causal: bool,
  5343. philox_seed: Tensor,
  5344. philox_offset: Tensor,
  5345. scale: float | None = None,
  5346. window_size_left: int | None = None,
  5347. window_size_right: int | None = None,
  5348. ):
  5349. grad_query = torch.empty_like(query)
  5350. grad_key = torch.empty_like(key)
  5351. grad_value = torch.empty_like(value)
  5352. return grad_query, grad_key, grad_value
  5353. @register_meta(
  5354. [
  5355. aten._efficient_attention_forward,
  5356. ]
  5357. )
  5358. def meta__efficient_attention_forward(
  5359. query: Tensor,
  5360. key: Tensor,
  5361. value: Tensor,
  5362. bias: Tensor | None,
  5363. cu_seqlens_q: Tensor | None,
  5364. cu_seqlens_k: Tensor | None,
  5365. max_seqlen_q: int | None,
  5366. max_seqlen_k: int | None,
  5367. dropout_p: float,
  5368. custom_mask_type: int,
  5369. compute_log_sumexp: bool = False,
  5370. scale: float | None = None,
  5371. causal_diagonal: Tensor | None = None,
  5372. seqlen_k: Tensor | None = None,
  5373. window_size: int | None = None,
  5374. ):
  5375. B = query.size(0)
  5376. M = query.size(1)
  5377. N = key.size(1)
  5378. num_heads = query.size(-2)
  5379. Kv = value.size(-1)
  5380. res = torch.empty(B, M, num_heads, Kv, dtype=query.dtype, device=query.device)
  5381. logsumexp_batch_dim = cu_seqlens_q.size(0) - 1 if (cu_seqlens_q is not None) else B
  5382. actual_max_seqlen_q = M
  5383. if cu_seqlens_q is not None:
  5384. assert max_seqlen_q is not None
  5385. actual_max_seqlen_q = max_seqlen_q
  5386. actual_max_seqlen_k = max_seqlen_k if max_seqlen_k is not None else N
  5387. logsumexp_dim = (
  5388. math.ceil(actual_max_seqlen_q / 32) * 32 if compute_log_sumexp else 0
  5389. )
  5390. logsum_exp = torch.empty(
  5391. (logsumexp_batch_dim, num_heads, logsumexp_dim),
  5392. dtype=torch.float,
  5393. device=query.device,
  5394. )
  5395. # See Note [Seed and Offset]:
  5396. seed = torch.empty((), dtype=torch.long, device="meta")
  5397. offset = torch.empty((), dtype=torch.long, device="meta")
  5398. return res, logsum_exp, seed, offset, actual_max_seqlen_q, actual_max_seqlen_k
  5399. @register_meta(
  5400. [
  5401. aten._efficient_attention_backward,
  5402. ]
  5403. )
  5404. def meta__efficient_attention_backward(
  5405. grad_out: Tensor,
  5406. query: Tensor,
  5407. key: Tensor,
  5408. value: Tensor,
  5409. bias: Tensor | None,
  5410. cu_seqlens_q: Tensor | None,
  5411. cu_seqlens_k: Tensor | None,
  5412. max_seqlen_q: torch.SymInt,
  5413. max_seqlen_k: torch.SymInt,
  5414. logsumexp: Tensor,
  5415. dropout_p: float,
  5416. philox_seed: Tensor,
  5417. philox_offset: Tensor,
  5418. custom_mask_type: int,
  5419. bias_requires_grad: bool,
  5420. scale: float | None = None,
  5421. num_splits_key: int | None = None,
  5422. shared_storage_dqdkdv: bool = False,
  5423. ):
  5424. if shared_storage_dqdkdv:
  5425. torch._check(
  5426. query.shape[1] == key.shape[1],
  5427. lambda: "seqlen must match for `shared_storage_dqdkdv",
  5428. )
  5429. torch._check(
  5430. query.shape[3] == key.shape[3],
  5431. lambda: "embedding dim must match for `shared_storage_dqdkdv",
  5432. )
  5433. chunk = torch.empty(
  5434. (*query.shape[0:-2], 3, query.shape[-2], query.shape[-1]),
  5435. dtype=query.dtype,
  5436. device=query.device,
  5437. )
  5438. grad_query = chunk.select(-3, 0)
  5439. grad_key = chunk.select(-3, 1)
  5440. grad_value = chunk.select(-3, 2)
  5441. else:
  5442. grad_query = torch.empty_like(query)
  5443. grad_key = torch.empty_like(key)
  5444. grad_value = torch.empty_like(value)
  5445. if bias is not None:
  5446. lastDim = bias.size(-1)
  5447. lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
  5448. new_sizes = list(bias.size())
  5449. new_sizes[-1] = lastDimAligned
  5450. grad_bias = torch.empty(new_sizes, dtype=bias.dtype, device=bias.device)
  5451. grad_bias = grad_bias[..., :lastDim]
  5452. else:
  5453. grad_bias = torch.empty((), device=query.device)
  5454. return grad_query, grad_key, grad_value, grad_bias
def _check_scaled_mm_sizes(
    self: torch.Tensor,
    mat2: torch.Tensor,
    scale_a: torch.Tensor,
    scale_b: torch.Tensor,
    bias: torch.Tensor | None = None,
    scale_result: torch.Tensor | None = None,
    out_dtype: torch.dtype | None = None,
    use_fast_accum: bool = False,
):
    """Validate ``_scaled_mm`` operands and return an empty result tensor.

    Always checks that ``self`` and ``mat2`` are 2-D and of an fp8/fp4 dtype.

    When ``device_hint(self)`` is "cuda" or "xpu" it additionally checks:
      * ``self`` is row-major and ``mat2`` column-major (unless a dim is 0);
      * ``self.size(1)`` and both dims of ``mat2`` are divisible by 16;
      * ``scale_a`` / ``scale_b`` match one supported scheme: tensorwise
        (single fp32 elements), blockwise (e8m0 or e4m3fn scales with an
        exact padded element count), rowwise (``(m, 1)`` / ``(1, n)``), or
        the (1x128, 128x128) / (1x128, 1x128) block variants.

    ``bias``, ``scale_result`` and ``use_fast_accum`` are not inspected here.

    Returns:
        An empty ``(self.size(0), mat2.size(1))`` tensor of ``out_dtype``,
        or of ``self.dtype`` when ``out_dtype`` is None.
    """

    def is_fp8_or_fp4_type(dtype):
        return dtype in (
            torch.float8_e4m3fn,
            torch.float8_e5m2,
            torch.float8_e4m3fnuz,
            torch.float8_e5m2fnuz,
            torch.float4_e2m1fn_x2,
        )

    torch._check(
        self.dim() == 2 and mat2.dim() == 2,
        lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
    )
    torch._check(
        is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype),
        lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
    )

    # Layout and scale-shape checks only apply on CUDA/XPU.
    if device_hint(self) == "cuda" or device_hint(self) == "xpu":

        def is_row_major(stride):
            return stride[0] > stride[1] and stride[1] == 1

        def is_col_major(stride):
            return stride[0] == 1 and stride[1] > 1

        def has_zero_dim(tensor_2d):
            return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0

        torch._check(
            is_row_major(self.stride()) or has_zero_dim(self),
            lambda: f"self must be row_major, got stride {self.stride()}",
        )
        torch._check(
            is_col_major(mat2.stride()) or has_zero_dim(mat2),
            lambda: f"mat2 must be col_major, got stride {mat2.stride()}",
        )
        torch._check(
            self.size(1) % 16 == 0,
            lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
        )
        torch._check(
            mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
            lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}",
        )

        # determine scaling type and check input dimensions (refer to Blas.cpp op)
        # self is (m, _k); n is mat2's column count.
        m, _k = self.shape
        n = mat2.size(1)

        is_blockwise_scaling = (
            (
                scale_a.dtype == torch.float8_e8m0fnu
                and scale_b.dtype == torch.float8_e8m0fnu
            )
            or (
                scale_a.dtype == torch.float8_e4m3fn
                and scale_b.dtype == torch.float8_e4m3fn
            )
        )  # note: this applies to blockwise scaling for non-FP8 types (FP8 accepts FP32 scales)

        if scale_a.numel() == 1 and scale_b.numel() == 1:
            # tensorwise scaling
            torch._check(
                scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
                lambda: "For tensorwise scaling, both scale_a and scale_b must be float (fp32) tensors.",
            )
        elif is_blockwise_scaling:
            # blockwise scaling

            if scale_a.dtype == torch.float8_e4m3fn:
                # NVIDIA's nvfp4 recipe:
                # * block size is 16 elements packed (32 unpacked)
                # * _k needs to be translated to the unpacked version
                block_size_k = 16
                _k = _k * 2
            else:
                block_size_k = 32

            # Expected element counts: rows (m or n) padded up to a multiple of
            # 128, times the number of K blocks padded up to a multiple of 4.
            # Presumably this is the swizzled scale layout; TODO confirm against
            # the CUDA kernel.
            block_size_mn = 128
            num_k_blocks = ceil_div(_k, block_size_k)
            padded_num_k_blocks = ceil_div(num_k_blocks, 4) * 4
            expected_a_size = (
                block_size_mn * ceil_div(m, block_size_mn) * padded_num_k_blocks
            )
            expected_b_size = (
                block_size_mn * ceil_div(n, block_size_mn) * padded_num_k_blocks
            )

            if (
                scale_a.numel() == expected_a_size
                and scale_b.numel() == expected_b_size
            ):
                torch._check(
                    scale_a.is_contiguous(),
                    lambda: "scale_a must be contiguous",
                )
                torch._check(
                    scale_b.is_contiguous(),
                    lambda: "scale_b must be contiguous",
                )
            else:
                torch._check(
                    False,
                    lambda: (
                        "Invalid blockwise scaling configuration. "
                        f"For blockwise scaling, scale_a should have {expected_a_size} elements, got {scale_a.numel()}, "
                        f"scale_b should have {expected_b_size} elements, got {scale_b.numel()}."
                    ),
                )
        else:
            # Remaining schemes all use fp32 scales.
            torch._check(
                scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32,
                lambda: "For rowwise scaling, both scale_a and scale_b must be float (fp32) tensors.",
            )

            # for rowwise scaling, enforce 2D input tensors
            torch._check(
                scale_a.dim() == 2 and scale_b.dim() == 2,
                lambda: f"For non-tensorwise scaling, scale tensors must be 2D, but got {scale_a.dim()=} and {scale_b.dim()=}",
            )

            if (
                scale_a.size(0) == m
                and scale_a.size(1) == 1
                and scale_b.size(0) == 1
                and scale_b.size(1) == n
            ):
                # rowwise scaling
                torch._check(
                    scale_a.is_contiguous() and scale_b.is_contiguous(),
                    lambda: "Both scale_a and scale_b must be contiguous for rowwise scaling.",
                )
            elif (
                scale_a.size(0) == m
                and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
                and scale_b.size(1) == ceil_div(n, 128)
            ):
                # (BlockWise1x128, BlockWise128x128)
                pass  # do nothing, but do not error
            elif (
                scale_a.size(0) == m
                and scale_a.size(1) == scale_b.size(0) == ceil_div(_k, 128)
                and scale_b.size(1) == n
            ):
                # (BlockWise1x128, BlockWise1x128)
                pass  # do nothing, but do not error
            else:
                # does not match any valid scaling type
                torch._check(
                    False,
                    lambda: (
                        "Invalid scaling configuration. "
                        "For tensorwise scaling, both scales should be scalar. "
                        f"For rowwise scaling, scale_a should be ({m}, 1), scale_b should be (1, {n}). "
                        f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
                        + f"scale_b should be ({ceil_div(_k, 128)}, {ceil_div(n, 128)}). "
                        f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({m}, {ceil_div(_k, 128)}), "
                        + f"scale_b should be ({ceil_div(_k, 128)}, {n}). "
                        f"Got scale_a.size()=({scale_a.size(0)}, {scale_a.size(1)}) "
                        f"and scale_b.size()=({scale_b.size(0)}, {scale_b.size(1)})"
                    ),
                )

    _out_dtype = out_dtype if out_dtype is not None else self.dtype
    return torch.empty(self.size(0), mat2.size(1), dtype=_out_dtype, device=self.device)
  5616. @register_meta([aten._scaled_mm.default])
  5617. def meta_scaled_mm(
  5618. self: torch.Tensor,
  5619. mat2: torch.Tensor,
  5620. scale_a: torch.Tensor,
  5621. scale_b: torch.Tensor,
  5622. bias: torch.Tensor | None = None,
  5623. scale_result: torch.Tensor | None = None,
  5624. out_dtype: torch.dtype | None = None,
  5625. use_fast_accum: bool = False,
  5626. ):
  5627. return _check_scaled_mm_sizes(
  5628. self, mat2, scale_a, scale_b, bias, scale_result, out_dtype, use_fast_accum
  5629. )
  5630. def _check_scaled_mm_sizes_v2(
  5631. self: torch.Tensor,
  5632. mat2: torch.Tensor,
  5633. scale_a: list[torch.Tensor],
  5634. scale_recipe_a: list[ScalingType],
  5635. scale_b: list[torch.Tensor],
  5636. scale_recipe_b: list[ScalingType],
  5637. bias: torch.Tensor | None = None,
  5638. out_dtype: torch.dtype | None = None,
  5639. swizzle_a: list[SwizzleType] | None = None,
  5640. swizzle_b: list[SwizzleType] | None = None,
  5641. use_fast_accum: bool = False,
  5642. ):
  5643. def is_fp8_or_fp4_type(dtype):
  5644. return dtype in (
  5645. torch.float8_e4m3fn,
  5646. torch.float8_e5m2,
  5647. torch.float8_e4m3fnuz,
  5648. torch.float8_e5m2fnuz,
  5649. torch.float4_e2m1fn_x2,
  5650. )
  5651. def is_fp4_type(dtype):
  5652. return dtype in (torch.float4_e2m1fn_x2,)
  5653. torch._check(
  5654. self.dim() == 2 and mat2.dim() == 2,
  5655. lambda: f"Inputs must be 2D but got self.dim()={self.dim()} and mat2.dim()={mat2.dim()}",
  5656. )
  5657. torch._check(
  5658. is_fp8_or_fp4_type(self.dtype) and is_fp8_or_fp4_type(mat2.dtype),
  5659. lambda: f"Expected both inputs to be fp8 or fp4 types but got self.dtype={self.dtype} and mat2.dtype={mat2.dtype}",
  5660. )
  5661. # Passed tensors:
  5662. # self: [M, K]
  5663. # mat2: [K, N]
  5664. M = self.shape[0]
  5665. K = self.shape[1]
  5666. N = mat2.shape[1]
  5667. # If we're using fp4, using fp4x2 packed format - adjust K appropriately
  5668. if is_fp4_type(self.dtype) and is_fp4_type(mat2.dtype):
  5669. K_packed_multiplier = 2
  5670. K *= K_packed_multiplier
  5671. scale_recipe_a = [ScalingType(si) for si in scale_recipe_a]
  5672. scale_recipe_b = [ScalingType(si) for si in scale_recipe_b]
  5673. if swizzle_a:
  5674. swizzle_a = [SwizzleType(si) for si in swizzle_a]
  5675. else:
  5676. swizzle_a = [
  5677. SwizzleType.NO_SWIZZLE,
  5678. ]
  5679. if swizzle_b:
  5680. swizzle_b = [SwizzleType(si) for si in swizzle_b]
  5681. else:
  5682. swizzle_b = [
  5683. SwizzleType.NO_SWIZZLE,
  5684. ]
  5685. if device_hint(self) == "cuda" or device_hint(self) == "xpu":
  5686. def is_row_major(stride):
  5687. return stride[0] > stride[1] and stride[1] == 1
  5688. def is_col_major(stride):
  5689. return stride[0] == 1 and stride[1] > 1
  5690. def has_zero_dim(tensor_2d):
  5691. return tensor_2d.size(0) == 0 or tensor_2d.size(1) == 0
  5692. torch._check(
  5693. is_row_major(self.stride()) or has_zero_dim(self),
  5694. lambda: f"self must be row_major, got stride {self.stride()}",
  5695. )
  5696. torch._check(
  5697. is_col_major(mat2.stride()) or has_zero_dim(mat2),
  5698. lambda: f"mat2 must be col_major, got stride {mat2.stride()}",
  5699. )
  5700. torch._check(
  5701. self.size(1) % 16 == 0,
  5702. lambda: f"Expected self.size(1) to be divisible by 16, but got self.size(1)={self.size(1)}",
  5703. )
  5704. torch._check(
  5705. mat2.size(0) % 16 == 0 and mat2.size(1) % 16 == 0,
  5706. lambda: f"Expected both dimensions of mat2 to be divisible by 16 but got {mat2.shape}",
  5707. )
  5708. def is_tensorwise(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5709. return (
  5710. len(recipe_a) == 1
  5711. and len(recipe_b) == 1
  5712. and recipe_a[0] == ScalingType.TensorWise
  5713. and recipe_b[0] == ScalingType.TensorWise
  5714. )
  5715. def is_rowwise(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5716. return (
  5717. len(recipe_a) == 1
  5718. and len(recipe_b) == 1
  5719. and recipe_a[0] == ScalingType.RowWise
  5720. and recipe_b[0] == ScalingType.RowWise
  5721. )
  5722. def is_mx(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5723. return (
  5724. len(recipe_a) == 1
  5725. and len(recipe_b) == 1
  5726. and recipe_a[0] == ScalingType.BlockWise1x32
  5727. and recipe_b[0] == ScalingType.BlockWise1x32
  5728. )
  5729. def is_nv(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5730. return (
  5731. len(recipe_a) == 2
  5732. and len(recipe_b) == 2
  5733. and recipe_a[0] == ScalingType.BlockWise1x16
  5734. and recipe_a[1] == ScalingType.TensorWise
  5735. and recipe_b[0] == ScalingType.BlockWise1x16
  5736. and recipe_b[1] == ScalingType.TensorWise
  5737. )
  5738. def is_1x128_1x128(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5739. return (
  5740. len(recipe_a) == 1
  5741. and len(recipe_b) == 1
  5742. and recipe_a[0] == ScalingType.BlockWise1x128
  5743. and recipe_b[0] == ScalingType.BlockWise1x128
  5744. )
  5745. def is_1x128_128x128(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5746. return (
  5747. len(recipe_a) == 1
  5748. and len(recipe_b) == 1
  5749. and recipe_a[0] == ScalingType.BlockWise1x128
  5750. and recipe_b[0] == ScalingType.BlockWise128x128
  5751. )
  5752. def is_128x128_1x128(recipe_a: list[ScalingType], recipe_b: list[ScalingType]):
  5753. return (
  5754. len(recipe_a) == 1
  5755. and len(recipe_b) == 1
  5756. and recipe_a[0] == ScalingType.BlockWise128x128
  5757. and recipe_b[0] == ScalingType.BlockWise1x128
  5758. )
  5759. # Given scaling types, check input dimensions
  5760. if is_tensorwise(scale_recipe_a, scale_recipe_b):
  5761. # TensorWise
  5762. torch._check(
  5763. scale_a[0].numel() == 1
  5764. and scale_b[0].numel() == 1
  5765. and scale_a[0].dtype == torch.float32
  5766. and scale_b[0].dtype == torch.float32,
  5767. lambda: "For Tensorwise scaling, both scale_a and scale_b must be single element float (fp32) tensors",
  5768. )
  5769. elif is_rowwise(scale_recipe_a, scale_recipe_b):
  5770. torch._check(
  5771. scale_a[0].shape[0] == M
  5772. and scale_a[0].numel() == M
  5773. and scale_a[0].dtype == torch.float32
  5774. and scale_b[0].numel() == N
  5775. and scale_b[0].dtype == torch.float32,
  5776. lambda: (
  5777. f"For Rowwise scaling, scale_a must have {self.shape[0]} elements (got: {scale_a[0].numel()})"
  5778. f", and scale_b must have {mat2.shape[1]} elements (got: {scale_b[0].numel()})"
  5779. ),
  5780. )
  5781. elif is_1x128_1x128(scale_recipe_a, scale_recipe_b):
  5782. # A, B are fp8, scales are fp32
  5783. # As: [M x K // 128], stride: [1, M]
  5784. # Bs: [N x K // 128], stride: [1, N]
  5785. types_ok = (
  5786. scale_a[0].dtype == torch.float32 and scale_b[0].dtype == torch.float32
  5787. )
  5788. sa = scale_a[0]
  5789. scale_a_ok = (
  5790. sa.shape[0] == M
  5791. and sa.shape[1] == K // 128
  5792. and sa.stride(0) == 1
  5793. and (sa.stride(1) == M or (sa.shape[1] == 1 and sa.stride(1) == 1))
  5794. )
  5795. sb = scale_b[0]
  5796. scale_b_ok = (
  5797. sb.shape[0] == N
  5798. and sb.shape[1] == K // 128
  5799. and sb.stride(0) == 1
  5800. and (sb.stride(1) == N or (sb.shape[1] == 1 and sb.stride(1) == 1))
  5801. )
  5802. torch._check(
  5803. types_ok and scale_a_ok and scale_b_ok,
  5804. lambda: (
  5805. "For 1x128 x 1x128 blockwise scaling, "
  5806. f"scale a must have shape [{M}, {K // 128}] (got: {sa.shape}) and stride [1, {M}] (got: {sa.stride})"
  5807. f"scale b must have shape [{N}, {K // 128}] (got: {sb.shape}) and stride [1, {N}] (got: {sb.stride})"
  5808. ),
  5809. )
  5810. elif is_128x128_1x128(scale_recipe_a, scale_recipe_b):
  5811. # A, B are fp8, scales are fp32
  5812. # L4 = round_up(K // 128, 4)
  5813. # As: [L4 x M // 128], stride: [1, L4]
  5814. # Bs: [N x K // 128], stride: [1, N]
  5815. types_ok = (
  5816. scale_a[0].dtype == torch.float32 and scale_b[0].dtype == torch.float32
  5817. )
  5818. L4 = round_up(K / 128, 4)
  5819. sa = scale_a[0]
  5820. scale_a_ok = (
  5821. sa.shape[0] == L4
  5822. and sa.shape[1] == M // 128
  5823. and sa.stride(0) == 1
  5824. and (sa.stride(1) == L4 or (sa.shape[1] == 1 and sa.stride(1) == 1))
  5825. )
  5826. sb = scale_b[0]
  5827. scale_b_ok = (
  5828. sb.shape[0] == N
  5829. and sb.shape[1] == K // 128
  5830. and sb.stride(0) == 1
  5831. and (sb.stride(1) == N or (sb.shape[1] == 1 and sb.stride(1) == 1))
  5832. )
  5833. torch._check(
  5834. types_ok and scale_a_ok and scale_b_ok,
  5835. lambda: (
  5836. "For 128x128 x 1x128 blockwise scaling, L4 = {round_up(K / 128, 4)}, "
  5837. f"scale a must have shape [{L4}, {M // 128}] (got: {sa.shape}) and stride [1, {L4}] (got: {sa.stride})"
  5838. f"scale b must have shape [{N}, {K // 128}] (got: {sb.shape}) and stride [1, {N}] (got: {sb.stride})"
  5839. ),
  5840. )
  5841. elif is_1x128_128x128(scale_recipe_a, scale_recipe_b):
  5842. # A, B are fp8, scales are fp32
  5843. # L4 = round_up(K // 128, 4)
  5844. # As: [M x K // 128], stride: [1, M]
  5845. # Bs: [L4 x N // 128], stride: [1, L4]
  5846. types_ok = (
  5847. scale_a[0].dtype == torch.float32 and scale_b[0].dtype == torch.float32
  5848. )
  5849. L4 = round_up(K / 128, 4)
  5850. sa = scale_a[0]
  5851. scale_a_ok = (
  5852. sa.shape[0] == M
  5853. and sa.shape[1] == K // 128
  5854. and sa.stride(0) == 1
  5855. and (sa.stride(1) == M or (sa.shape[1] == 1 and sa.stride(1) == 1))
  5856. )
  5857. sb = scale_b[0]
  5858. scale_b_ok = (
  5859. sb.shape[0] == L4
  5860. and sb.shape[1] == N // 128
  5861. and sb.stride(0) == 1
  5862. and (sb.stride(1) == L4 or (sb.shape[1] == 1 and sb.stride(1) == 1))
  5863. )
  5864. torch._check(
  5865. types_ok and scale_a_ok and scale_b_ok,
  5866. lambda: (
  5867. "For 1x128 x 128x128 blockwise scaling, L4 = {round_up(K / 128, 4)}, "
  5868. f"scale a must have shape [{M}, {K // 128}] (got: {sa.shape}) and stride [1, {M}] (got: {sa.stride})"
  5869. f"scale b must have shape [{L4}, {N // 128}] (got: {sb.shape}) and stride [1, {L4}] (got: {sb.stride})"
  5870. ),
  5871. )
  5872. elif is_mx(scale_recipe_a, scale_recipe_b):
  5873. if torch.version.hip:
  5874. # Note(slayton58): These mirror ROCm in ScaledBlas.cpp, but I think they're wrong..
  5875. expected_scale_a_elems = ceil_div(self.shape[0], 32) * self.shape[1]
  5876. expected_scale_b_elems = ceil_div(self.shape[1], 32) * self.shape[0]
  5877. expected_swizzle = SwizzleType.NO_SWIZZLE
  5878. else:
  5879. expected_scale_a_elems = round_up(self.shape[0], 128) * round_up(
  5880. ceil_div(self.shape[1], 32), 4
  5881. )
  5882. expected_scale_b_elems = round_up(mat2.shape[1], 128) * round_up(
  5883. ceil_div(self.shape[1], 32), 4
  5884. )
  5885. expected_swizzle = SwizzleType.SWIZZLE_32_4_4
  5886. torch._check(
  5887. scale_a[0].numel() == expected_scale_a_elems
  5888. and scale_a[0].dtype == torch.float8_e8m0fnu
  5889. and scale_b[0].numel() == expected_scale_b_elems
  5890. and scale_b[0].dtype == torch.float8_e8m0fnu
  5891. and swizzle_a[0] == expected_swizzle
  5892. and swizzle_b[0] == expected_swizzle,
  5893. lambda: (
  5894. f"for MX scaling scale_a must have {expected_scale_a_elems} (got: {scale_a[0].numel()}) "
  5895. f"and scale_b must have {expected_scale_b_elems} (got: {scale_b[0].numel()}). Scales must "
  5896. f"have types {torch.float8_e8m0fnu} (for self: {scale_a[0].dtype}, mat_b: {scale_b[0].dtype}) "
  5897. f"Must have swizzle type {expected_swizzle} (got self: {swizzle_a[0]}, mat_b: {swizzle_b[0]})"
  5898. ),
  5899. )
  5900. elif is_nv(scale_recipe_a, scale_recipe_b):
  5901. expected_scale_a_elems = round_up(M, 128) * round_up(ceil_div(K, 16), 4)
  5902. expected_scale_b_elems = round_up(N, 128) * round_up(ceil_div(K, 16), 4)
  5903. expected_swizzle = SwizzleType.SWIZZLE_32_4_4
  5904. torch._check(
  5905. scale_a[0].numel() == expected_scale_a_elems
  5906. and scale_a[0].dtype == torch.float8_e4m3fn
  5907. and scale_a[1].numel() == 1
  5908. and scale_a[1].dtype == torch.float32
  5909. and scale_b[0].numel() == expected_scale_b_elems
  5910. and scale_b[0].dtype == torch.float8_e4m3fn
  5911. and scale_b[1].numel() == 1
  5912. and scale_b[1].dtype == torch.float32
  5913. and swizzle_a[0] == expected_swizzle
  5914. and swizzle_b[0] == expected_swizzle,
  5915. lambda: (
  5916. f"for NV scaling scale_a must have {expected_scale_a_elems} (got: {scale_a[0].numel()}) "
  5917. f"and scale_b must have {expected_scale_b_elems} (got: {scale_b[0].numel()}). Must have "
  5918. f"swizzle type {expected_swizzle} (got self: {swizzle_a[0]}, mat_b: {swizzle_b[0]})"
  5919. ),
  5920. )
  5921. else:
  5922. torch._check(
  5923. False,
  5924. lambda: (
  5925. "Invalid scaling configuration. "
  5926. "For tensorwise scaling, both scales should be scalar. "
  5927. f"For rowwise scaling, scale_a should be ({M}, 1), scale_b should be (1, {N}). "
  5928. f"For (BlockWise1x128, BlockWise128x128), scale_a should be ({M}, {ceil_div(K, 128)}), "
  5929. + f"scale_b should be ({ceil_div(K, 128)}, {ceil_div(N, 128)}). "
  5930. f"For (BlockWise1x128, BlockWise1x128), scale_a should be ({M}, {ceil_div(K, 128)}), "
  5931. + f"scale_b should be ({ceil_div(K, 128)}, {N}). "
  5932. f"Got scale_a.size()=({scale_a[0].size(0)}, {scale_a[0].size(1)}) "
  5933. f"and scale_b.size()=({scale_b[0].size(0)}, {scale_b[0].size(1)})"
  5934. ),
  5935. )
  5936. _out_dtype = out_dtype if out_dtype is not None else self.dtype
  5937. return torch.empty(M, N, dtype=_out_dtype, device=self.device)
  5938. @register_meta([aten._scaled_mm_v2.default])
  5939. def meta_scaled_mm_v2(
  5940. self: torch.Tensor,
  5941. mat2: torch.Tensor,
  5942. scale_a: list[torch.Tensor],
  5943. scale_recipe_a: list[ScalingType],
  5944. swizzle_a: list[SwizzleType],
  5945. scale_b: list[torch.Tensor],
  5946. scale_recipe_b: list[ScalingType],
  5947. swizzle_b: list[SwizzleType],
  5948. bias: torch.Tensor | None = None,
  5949. output_dtype: torch.dtype | None = None,
  5950. contraction_dims: list[int] | None = None,
  5951. use_fast_accum: bool = False,
  5952. ):
  5953. return _check_scaled_mm_sizes_v2(
  5954. self,
  5955. mat2,
  5956. scale_a,
  5957. scale_recipe_a,
  5958. scale_b,
  5959. scale_recipe_b,
  5960. bias=bias,
  5961. out_dtype=output_dtype,
  5962. swizzle_a=swizzle_a,
  5963. swizzle_b=swizzle_b,
  5964. use_fast_accum=use_fast_accum,
  5965. )
  5966. @register_meta([aten.scatter_reduce.two, aten.scatter_reduce.two_out])
  5967. @out_wrapper()
  5968. def meta_scatter_reduce_two(self, dim, index, src, reduce, include_self=True):
  5969. scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
  5970. return self.new_empty(self.shape)
  5971. @register_meta(aten.scatter_reduce_.two)
  5972. def meta_scatter_reduce__two(self, dim, index, src, reduce, include_self=True):
  5973. scatter_meta_impl(self, dim, index, src, reduce, use_new_options=True)
  5974. return self
  5975. @register_meta([aten.multinomial.default, aten.multinomial.out])
  5976. @out_wrapper()
  5977. def meta_multinomial(input, num_samples, replacement=False, *, generator=None):
  5978. torch._check(
  5979. 0 < input.dim() <= 2,
  5980. lambda: f"The probability distributions dimensions must be 1 or 2, but got {input.dim()}",
  5981. )
  5982. if input.dim() == 1:
  5983. return torch.empty(num_samples, dtype=torch.long, device=input.device)
  5984. return torch.empty(
  5985. input.size(0), num_samples, dtype=torch.long, device=input.device
  5986. )
  5987. def multiply_integers(vs):
  5988. r = 1
  5989. for v in vs:
  5990. r *= v
  5991. return r
  5992. def upsample_common_check(input_size, output_size, num_spatial_dims):
  5993. torch._check(
  5994. len(output_size) == num_spatial_dims,
  5995. lambda: f"It is expected output_size equals to {num_spatial_dims}, but got size {len(output_size)}",
  5996. )
  5997. expected_input_dims = num_spatial_dims + 2 # N, C, ...
  5998. torch._check(
  5999. len(input_size) == expected_input_dims,
  6000. lambda: f"It is expected input_size equals to {expected_input_dims}, but got size {len(input_size)}",
  6001. )
  6002. torch._check(
  6003. all(s > 0 for s in input_size[2:]) and all(s > 0 for s in output_size),
  6004. lambda: f"Input and output sizes should be greater than 0, but got "
  6005. f"input size {input_size} and output size {output_size}",
  6006. )
  6007. nbatch, channels = input_size[:2]
  6008. return (nbatch, channels, *output_size)
  6009. @register_meta(
  6010. [aten.upsample_nearest1d.default, aten._upsample_nearest_exact1d.default]
  6011. )
  6012. def upsample_nearest1d(input, output_size, scales=None):
  6013. torch._check(
  6014. input.numel() != 0 or multiply_integers(input.size()[1:]),
  6015. lambda: f"Non-empty 3D data tensor expected but got a tensor with sizes {input.size()}",
  6016. )
  6017. full_output_size = upsample_common_check(
  6018. input.size(), output_size, num_spatial_dims=1
  6019. )
  6020. return input.new_empty(full_output_size).to(
  6021. memory_format=utils.suggest_memory_format(input)
  6022. )
  6023. @register_meta(
  6024. [aten.upsample_nearest2d.default, aten._upsample_nearest_exact2d.default]
  6025. )
  6026. def upsample_nearest2d(input, output_size, scales_h=None, scales_w=None):
  6027. torch._check(
  6028. input.numel() != 0 or multiply_integers(input.size()[1:]),
  6029. lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
  6030. )
  6031. full_output_size = upsample_common_check(
  6032. input.size(), output_size, num_spatial_dims=2
  6033. )
  6034. output = input.new_empty(full_output_size)
  6035. # convert output to correct memory format, if necessary
  6036. memory_format = utils.suggest_memory_format(input)
  6037. # following "heuristic: only use channels_last path when it's faster than the contiguous path"
  6038. _, n_channels, _, _ = input.shape
  6039. if input.device.type == "cuda" and n_channels < 4:
  6040. memory_format = torch.contiguous_format
  6041. output = output.contiguous(memory_format=memory_format)
  6042. return output
  6043. @register_meta(
  6044. [
  6045. aten.upsample_nearest2d_backward.default,
  6046. aten._upsample_nearest_exact2d_backward.default,
  6047. ]
  6048. )
  6049. def upsample_nearest2d_backward(
  6050. grad_output: Tensor,
  6051. output_size: Sequence[int | torch.SymInt],
  6052. input_size: Sequence[int | torch.SymInt],
  6053. scales_h: float | None = None,
  6054. scales_w: float | None = None,
  6055. ):
  6056. full_output_size = upsample_common_check(
  6057. input_size, output_size, num_spatial_dims=2
  6058. )
  6059. torch._check(
  6060. grad_output.ndim == 4,
  6061. lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
  6062. )
  6063. for i in range(4):
  6064. torch._check(
  6065. grad_output.size(i) == full_output_size[i],
  6066. lambda: (
  6067. f"Expected grad_output to have the same shape as output;"
  6068. f" output.size({i}) = {full_output_size[i]}"
  6069. f" but got grad_output.size({i}) = {grad_output.size(i)}"
  6070. ),
  6071. )
  6072. return grad_output.new_empty(input_size).to(
  6073. memory_format=utils.suggest_memory_format(grad_output)
  6074. ) # type: ignore[call-overload]
  6075. @register_meta(
  6076. [aten.upsample_nearest3d.default, aten._upsample_nearest_exact3d.default]
  6077. )
  6078. def upsample_nearest3d(input, output_size, scales_d=None, scales_h=None, scales_w=None):
  6079. torch._check(
  6080. input.numel() != 0 or multiply_integers(input.size()[1:]),
  6081. lambda: f"Non-empty 5D data tensor expected but got a tensor with sizes {input.size()}",
  6082. )
  6083. full_output_size = upsample_common_check(
  6084. input.size(), output_size, num_spatial_dims=3
  6085. )
  6086. return input.new_empty(full_output_size).to(
  6087. memory_format=utils.suggest_memory_format(input)
  6088. )
  6089. @register_meta(
  6090. [
  6091. aten.sort.default,
  6092. aten.sort.stable,
  6093. aten.sort.values,
  6094. aten.sort.values_stable,
  6095. ]
  6096. )
  6097. def meta_sort(self, stable=None, dim=-1, descending=False, values=None, indices=None):
  6098. v, i = torch.empty_like(self), torch.empty_like(self, dtype=torch.int64)
  6099. if values is not None and indices is not None:
  6100. assert isinstance(values, TensorLike)
  6101. assert isinstance(indices, TensorLike)
  6102. # Makes sure values and indices have the same strides. For cases where
  6103. # these have different shapes, like (5, 10, 5) and (0) in msort.
  6104. out_shape = v.shape
  6105. out_stride = v.stride()
  6106. values = _maybe_resize_out(values, out_shape)
  6107. indices = _maybe_resize_out(indices, out_shape)
  6108. values.as_strided_(out_shape, out_stride)
  6109. indices.as_strided_(out_shape, out_stride)
  6110. _safe_copy_out(copy_from=v, copy_to=values) # type: ignore[arg-type]
  6111. _safe_copy_out(copy_from=i, copy_to=indices) # type: ignore[arg-type]
  6112. return values, indices
  6113. return v, i
  6114. def rnn_cell_checkSizes(
  6115. input_gates,
  6116. hidden_gates,
  6117. input_bias,
  6118. hidden_bias,
  6119. factor,
  6120. prev_hidden,
  6121. ):
  6122. torch._check(input_gates.ndim == 2, lambda: f"{input_gates.ndim} != 2")
  6123. torch._check(
  6124. input_gates.shape == hidden_gates.shape,
  6125. lambda: f"{input_gates.shape} != {hidden_gates.shape}",
  6126. )
  6127. gates_size = input_gates.size(1)
  6128. if input_bias is not None:
  6129. torch._check(input_bias.ndim == 1, lambda: f"{input_bias.ndim} != 1")
  6130. torch._check(
  6131. input_bias.numel() == gates_size,
  6132. lambda: f"{input_bias.numel()} != {gates_size}",
  6133. )
  6134. torch._check(
  6135. input_bias.shape == hidden_bias.shape,
  6136. lambda: f"{input_bias.shape} != {hidden_bias.shape}",
  6137. )
  6138. torch._check(prev_hidden.ndim == 2, lambda: f"{prev_hidden.ndim} != 2")
  6139. expected_prev_hidden_numel = input_gates.size(0) * gates_size // factor
  6140. torch._check(
  6141. prev_hidden.numel() == expected_prev_hidden_numel,
  6142. lambda: f"{prev_hidden.numel()} != {input_gates.size(0)} * {gates_size} // {factor} (aka {expected_prev_hidden_numel})",
  6143. )
  6144. torch._check(
  6145. all(
  6146. # pyrefly: ignore [missing-attribute]
  6147. x.device == input_gates.device
  6148. for x in [hidden_gates, input_bias, hidden_bias, prev_hidden]
  6149. ),
  6150. lambda: "expected all inputs to be same device",
  6151. )
  6152. @register_meta(aten._thnn_fused_lstm_cell.default)
  6153. def _thnn_fused_lstm_cell_meta(
  6154. input_gates,
  6155. hidden_gates,
  6156. cx,
  6157. input_bias=None,
  6158. hidden_bias=None,
  6159. ):
  6160. rnn_cell_checkSizes(input_gates, hidden_gates, input_bias, hidden_bias, 4, cx)
  6161. workspace = torch.empty_like(input_gates, memory_format=torch.contiguous_format)
  6162. hy = torch.empty_like(cx, memory_format=torch.contiguous_format)
  6163. cy = torch.empty_like(cx, memory_format=torch.contiguous_format)
  6164. return (hy, cy, workspace)
  6165. @register_meta(aten._cudnn_rnn.default)
  6166. def _cudnn_rnn(
  6167. input,
  6168. weight,
  6169. weight_stride0,
  6170. weight_buf,
  6171. hx,
  6172. cx,
  6173. mode,
  6174. hidden_size,
  6175. proj_size,
  6176. num_layers,
  6177. batch_first,
  6178. dropout,
  6179. train,
  6180. bidirectional,
  6181. batch_sizes,
  6182. dropout_state,
  6183. ):
  6184. is_input_packed = len(batch_sizes) != 0
  6185. if is_input_packed:
  6186. seq_length = len(batch_sizes)
  6187. mini_batch = batch_sizes[0]
  6188. batch_sizes_sum = input.shape[0]
  6189. else:
  6190. seq_length = input.shape[1] if batch_first else input.shape[0]
  6191. mini_batch = input.shape[0] if batch_first else input.shape[1]
  6192. batch_sizes_sum = -1
  6193. num_directions = 2 if bidirectional else 1
  6194. out_size = proj_size if proj_size != 0 else hidden_size
  6195. if is_input_packed:
  6196. out_shape = [batch_sizes_sum, out_size * num_directions]
  6197. else:
  6198. out_shape = (
  6199. [mini_batch, seq_length, out_size * num_directions]
  6200. if batch_first
  6201. else [seq_length, mini_batch, out_size * num_directions]
  6202. )
  6203. output = input.new_empty(out_shape)
  6204. cell_shape = [num_layers * num_directions, mini_batch, hidden_size]
  6205. if cx is None:
  6206. cy = torch.empty(0, device=input.device)
  6207. else:
  6208. cy = cx.new_empty(cell_shape)
  6209. hy = hx.new_empty([num_layers * num_directions, mini_batch, out_size])
  6210. # TODO: Query cudnnGetRNNTrainingReserveSize (expose to python)
  6211. reserve_shape = 0 if train else 0
  6212. reserve = input.new_empty(reserve_shape, dtype=torch.uint8)
  6213. return output, hy, cy, reserve, weight_buf
  6214. @register_meta(aten.mkldnn_rnn_layer.default)
  6215. def mkldnn_rnn_layer(
  6216. input,
  6217. w0,
  6218. w1,
  6219. w2,
  6220. w3,
  6221. hx_,
  6222. cx_,
  6223. reverse,
  6224. batch_sizes,
  6225. mode,
  6226. hidden_size,
  6227. num_layers,
  6228. has_biases,
  6229. bidirectional,
  6230. batch_first,
  6231. train,
  6232. ):
  6233. seq_length = input.shape[1] if batch_first else input.shape[0]
  6234. mini_batch = input.shape[0] if batch_first else input.shape[1]
  6235. output_chanels = hidden_size
  6236. out_shape = (
  6237. [mini_batch, seq_length, output_chanels]
  6238. if batch_first
  6239. else [seq_length, mini_batch, output_chanels]
  6240. )
  6241. output = input.new_empty(out_shape)
  6242. if hx_ is None:
  6243. hy = torch.empty(0, device=input.device)
  6244. else:
  6245. hy = hx_.new_empty(hx_.shape)
  6246. if cx_ is None:
  6247. cy = torch.empty(0, device=input.device)
  6248. else:
  6249. cy = cx_.new_empty(cx_.shape)
  6250. workspace = torch.empty(0, device=input.device, dtype=torch.uint8)
  6251. return output, hy, cy, workspace
  6252. def zero_numel_check_dims(self, dim, fn_name):
  6253. if self.ndim == 0:
  6254. torch._check_index(
  6255. dim == 0 or dim == -1,
  6256. lambda: f"{fn_name}: Expected reduction dim -1 or 0 for scalar but got {dim}",
  6257. )
  6258. else:
  6259. torch._check_index(
  6260. self.size(dim) != 0,
  6261. lambda: f"{fn_name}: Expected reduction dim {dim} to have non-zero size.",
  6262. )
  6263. # From aten/src/ATen/native/ReduceOps.cpp
  6264. def check_argmax_argmin(name, self, dim):
  6265. if dim is not None:
  6266. dim = maybe_wrap_dim(dim, self.dim())
  6267. zero_numel_check_dims(self, dim, name)
  6268. else:
  6269. torch._check(
  6270. self.numel() != 0,
  6271. lambda: f"{name}: Expected reduction dim to be specified for input.numel() == 0.",
  6272. )
  6273. @register_meta([aten.argmax.default, aten.argmin.default])
  6274. def argmax_argmin_meta(self, dim=None, keepdim=False):
  6275. check_argmax_argmin("argmax", self, dim)
  6276. dims = utils.reduction_dims(self.shape, (dim,) if dim is not None else None)
  6277. shape = _compute_reduction_shape(self, dims, keepdim)
  6278. return self.new_empty(shape, dtype=torch.int64)
  6279. @register_meta(aten.scalar_tensor.default)
  6280. def scalar_tensor(s, dtype=None, layout=None, device=None, pin_memory=None):
  6281. # NB: It's always wrong to try to create a scalar tensor with the jagged layout.
  6282. # Rather than fix this everywhere, just use the strided layout and let NJT handle
  6283. # scalar tensor broadcasting.
  6284. if layout == torch.jagged:
  6285. layout = torch.strided
  6286. return torch.empty(
  6287. (), dtype=dtype, layout=layout, device=device, pin_memory=pin_memory
  6288. )
  6289. @register_meta(aten.topk.default)
  6290. def topk_meta(self, k, dim=-1, largest=True, sorted=True):
  6291. # From aten/src/ATen/native/Sorting.cpp
  6292. dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
  6293. sliceSize = 1 if self.dim() == 0 else self.size(dim)
  6294. torch._check(k >= 0)
  6295. torch._check(k <= sliceSize, lambda: "k not in range for dimension")
  6296. topKSize = list(self.shape)
  6297. if len(topKSize) > 0:
  6298. topKSize[dim] = k
  6299. return self.new_empty(topKSize), self.new_empty(topKSize, dtype=torch.int64)
  6300. @register_meta(aten._segment_reduce_backward)
  6301. @out_wrapper()
  6302. def meta__segment_reduce_backward(
  6303. grad, output, data, reduce, lengths=None, offsets=None, axis=0, initial=None
  6304. ):
  6305. assert lengths is not None or offsets is not None, (
  6306. "segment_reduce(): Either lengths or offsets must be defined"
  6307. )
  6308. data_contig = data.contiguous()
  6309. grad_contig = grad.contiguous()
  6310. return torch.empty_like(
  6311. data_contig,
  6312. dtype=grad_contig.dtype,
  6313. device=grad_contig.device,
  6314. layout=grad_contig.layout,
  6315. )
  6316. @register_meta([aten.kthvalue.default, aten.kthvalue.values])
  6317. @out_wrapper("values", "indices")
  6318. def kthvalue_meta(self, k, dim=-1, keepdim=False):
  6319. from torch.fx.experimental.symbolic_shapes import sym_and
  6320. dim = maybe_wrap_dim(dim, self.dim(), wrap_scalar=True)
  6321. dimSize = self.size(dim) if self.dim() > 0 else 1
  6322. torch._check(
  6323. sym_and(k >= 1, k <= dimSize),
  6324. lambda: f"kthvalue(): selected number k out of range for dimension {dim}",
  6325. )
  6326. shape = list(self.shape[:dim] + self.shape[dim + 1 :])
  6327. if keepdim and self.dim() > 0:
  6328. shape.insert(dim, 1)
  6329. return self.new_empty(shape), self.new_empty(shape, dtype=torch.int64)
# Alias of torch.contiguous_format. The fused LSTM-cell backward meta
# (_thnn_fused_lstm_cell_backward_impl) allocates its gradients with it.
legacy_contiguous_memory_format = torch.contiguous_format
  6331. # From aten/src/ATen/native/cuda/RNN.cu
  6332. def checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace):
  6333. defined_grad = grad_hy if grad_hy is not None else grad_cy
  6334. torch._check(defined_grad.dim() == 2, lambda: "")
  6335. exp_size = defined_grad.size()
  6336. if grad_hy is not None:
  6337. torch._check(grad_hy.size() == exp_size, lambda: "")
  6338. if grad_cy is not None:
  6339. torch._check(grad_cy.size() == exp_size, lambda: "")
  6340. torch._check(cx.size() == exp_size, lambda: "")
  6341. torch._check(cy.size() == exp_size, lambda: "")
  6342. torch._check(workspace.dim() == 2, lambda: "")
  6343. torch._check(workspace.numel() == exp_size[0] * exp_size[1] * 4, lambda: "")
  6344. # From aten/src/ATen/native/cuda/RNN.cu
  6345. @register_meta(aten._thnn_fused_lstm_cell_backward_impl.default)
  6346. def _thnn_fused_lstm_cell_backward_impl(grad_hy, grad_cy, cx, cy, workspace, has_bias):
  6347. if grad_hy is None and grad_cy is None:
  6348. return None, None, None
  6349. checkLSTMBackwardSizes(grad_hy, grad_cy, cx, cy, workspace)
  6350. grad_gates = torch.empty_like(
  6351. workspace, memory_format=legacy_contiguous_memory_format
  6352. )
  6353. grad_cx = torch.empty_like(cx, memory_format=legacy_contiguous_memory_format)
  6354. grad_bias = grad_gates.sum(0, keepdim=False) if has_bias else None
  6355. return grad_gates, grad_cx, grad_bias
  6356. # From aten/src/ATen/native/mps/operations/Linear.mm
  6357. @register_meta(aten.linear_backward.default)
  6358. def linear_backward(input_, grad_output_, weight_, output_mask):
  6359. grad_input = None
  6360. grad_weight = None
  6361. grad_bias = None
  6362. if output_mask[0]:
  6363. grad_input = grad_output_.new_empty(input_.size())
  6364. if output_mask[1] or output_mask[2]:
  6365. grad_weight = grad_output_.new_empty((grad_output_.size(-1), input_.size(-1)))
  6366. grad_bias = grad_output_.new_empty(grad_output_.size(-1))
  6367. return (grad_input, grad_weight, grad_bias)
  6368. @register_meta(aten.pixel_shuffle.default)
  6369. def meta_pixel_shuffle(self, upscale_factor):
  6370. assert (
  6371. len(self.shape) > 2 and self.shape[-3] % (upscale_factor * upscale_factor) == 0
  6372. ), (
  6373. f"Invalid input shape for pixel_shuffle: {self.shape} with upscale_factor = {upscale_factor}"
  6374. )
  6375. def is_channels_last(ten):
  6376. return torch._prims_common.suggest_memory_format(ten) == torch.channels_last
  6377. def pick_memory_format():
  6378. if is_channels_last(self):
  6379. if device_hint(self) == "cuda":
  6380. return torch.contiguous_format
  6381. else:
  6382. return torch.channels_last
  6383. elif self.is_contiguous(memory_format=torch.contiguous_format):
  6384. return torch.contiguous_format
  6385. elif self.is_contiguous(memory_format=torch.preserve_format):
  6386. return torch.preserve_format
  6387. C = self.shape[-3] // (upscale_factor * upscale_factor)
  6388. Hr = self.shape[-2] * upscale_factor
  6389. Wr = self.shape[-1] * upscale_factor
  6390. out_shape = (*self.shape[:-3], C, Hr, Wr)
  6391. out = self.new_empty(out_shape)
  6392. out = out.to(memory_format=pick_memory_format()) # type: ignore[call-overload]
  6393. return out
  6394. @register_meta(aten.mkldnn_rnn_layer_backward.default)
  6395. def mkldnn_rnn_layer_backward(
  6396. input,
  6397. weight0,
  6398. weight1,
  6399. weight2,
  6400. weight3,
  6401. hx_,
  6402. cx_tmp,
  6403. output,
  6404. hy_,
  6405. cy_,
  6406. grad_output_r_opt,
  6407. grad_hy_r_opt,
  6408. grad_cy_r_opt,
  6409. reverse,
  6410. mode,
  6411. hidden_size,
  6412. num_layers,
  6413. has_biases,
  6414. train,
  6415. bidirectional,
  6416. batch_sizes,
  6417. batch_first,
  6418. workspace,
  6419. ):
  6420. diff_x = input.new_empty(input.shape)
  6421. diff_hx = hx_.new_empty(hx_.shape)
  6422. diff_cx = cx_tmp.new_empty(cx_tmp.shape)
  6423. diff_w1 = weight0.new_empty(weight0.shape)
  6424. diff_w2 = weight1.new_empty(weight1.shape)
  6425. diff_b = weight2.new_empty(weight2.shape)
  6426. return diff_x, diff_w1, diff_w2, diff_b, diff_b, diff_hx, diff_cx
  6427. @register_meta([aten.bucketize.Tensor, aten.bucketize.Tensor_out])
  6428. @out_wrapper()
  6429. def meta_bucketize(self, boundaries, *, out_int32=False, right=False):
  6430. return torch.empty_like(
  6431. self,
  6432. dtype=torch.int32 if out_int32 else torch.int64,
  6433. memory_format=torch.contiguous_format,
  6434. )
  6435. @register_meta([aten.histc])
  6436. @out_wrapper()
  6437. def meta_histc(input, bins=100, min=0, max=0):
  6438. fn_name = "histc()"
  6439. if device_hint(input) == "cpu":
  6440. torch._check(
  6441. input.is_floating_point(),
  6442. lambda: f"\"histogram_cpu\" not implemented for '{input.dtype}'",
  6443. )
  6444. if device_hint(input) == "cuda" and input.is_floating_point():
  6445. utils.alert_not_deterministic("_histc_cuda with floating point input")
  6446. torch._check(
  6447. isinstance(bins, IntLike),
  6448. lambda: f"{fn_name}: argument 'bins' must be int, not {type(bins)}",
  6449. )
  6450. torch._check(bins > 0, lambda: f"{fn_name}: bins must be > 0, but got {bins}")
  6451. torch._check(
  6452. isinstance(min, Number),
  6453. lambda: f"{fn_name}: argument 'min' must be Number, not {type(min)}",
  6454. )
  6455. torch._check(
  6456. isinstance(max, Number),
  6457. lambda: f"{fn_name}: argument 'max' must be Number, not {type(max)}",
  6458. )
  6459. torch._check(max >= min, lambda: f"{fn_name}: max must be larger than min")
  6460. return torch.empty(bins, device=input.device, dtype=input.dtype)
  6461. @register_meta(
  6462. [aten._upsample_bilinear2d_aa.default, aten._upsample_bicubic2d_aa.default]
  6463. )
  6464. def meta_upsample_bimode2d_aa(
  6465. input,
  6466. output_size,
  6467. align_corners,
  6468. scales_h=None,
  6469. scales_w=None,
  6470. ):
  6471. full_output_size = upsample_common_check(
  6472. input.size(), output_size, num_spatial_dims=2
  6473. )
  6474. torch._check(
  6475. input.numel() != 0 or all(size > 0 for size in input.size()[1:]),
  6476. lambda: f"Non-empty 4D data tensor expected but got a tensor with sizes {input.size()}",
  6477. )
  6478. return input.new_empty(full_output_size).to(
  6479. memory_format=utils.suggest_memory_format(input)
  6480. )
  6481. @register_meta([aten._upsample_bilinear2d_aa_backward.default])
  6482. def meta_upsample_bimode2d_aa_backward(
  6483. grad_output,
  6484. output_size,
  6485. input_size,
  6486. align_corners,
  6487. scales_h=None,
  6488. scales_w=None,
  6489. ):
  6490. full_output_size = upsample_common_check(
  6491. input_size, output_size, num_spatial_dims=2
  6492. )
  6493. torch._check(
  6494. grad_output.ndim == 4,
  6495. lambda: f"Expected grad_output to be a tensor of dimension 4 but got: dimension {grad_output.ndim}",
  6496. )
  6497. for i in range(4):
  6498. torch._check(
  6499. grad_output.shape[i] == full_output_size[i],
  6500. lambda: f"""
  6501. Expected grad_output to have the same shape as output; output.size({i}) = {full_output_size[i]}
  6502. but got grad_output_size({i}) = {grad_output.size(i)}""",
  6503. )
  6504. return grad_output.new_empty(input_size).to(
  6505. memory_format=utils.suggest_memory_format(grad_output)
  6506. )
  6507. # From aten/src/ATen/native/cuda/AmpKernels.cu
  6508. @register_meta(aten._amp_foreach_non_finite_check_and_unscale_.default)
  6509. def _amp_foreach_non_finite_check_and_unscale_(self, found_inf, inv_scale):
  6510. torch._check(
  6511. found_inf.numel() == 1, lambda: "found_inf must be a 1-element tensor."
  6512. )
  6513. torch._check(
  6514. inv_scale.numel() == 1, lambda: "inv_scale must be a 1-element tensor."
  6515. )
  6516. torch._check(
  6517. found_inf.dtype.is_floating_point,
  6518. lambda: "found_inf must be a float tensor.",
  6519. )
  6520. torch._check(
  6521. inv_scale.dtype.is_floating_point,
  6522. lambda: "inv_scale must be a float tensor.",
  6523. )
  6524. # From aten/src/ATen/native/UnaryOps.cpp
  6525. @register_meta([aten.nan_to_num.default, aten.nan_to_num.out])
  6526. @out_wrapper()
  6527. def nan_to_num(self, nan=None, posinf=None, neginf=None):
  6528. return torch.empty_like(self)
  6529. @register_meta(torch.ops.aten.transpose_)
  6530. def transpose_(self, dim0, dim1):
  6531. assert self.layout not in {
  6532. torch.sparse_csr,
  6533. torch.sparse_csc,
  6534. torch.sparse_bsr,
  6535. torch.sparse_bsc,
  6536. }, (
  6537. f"torch.transpose_: in-place transposition is not supported for {self.layout} layout"
  6538. )
  6539. ndims = self.ndim
  6540. dim0 = maybe_wrap_dim(dim0, ndims)
  6541. dim1 = maybe_wrap_dim(dim1, ndims)
  6542. if dim0 == dim1:
  6543. return self
  6544. size = list(self.size())
  6545. stride = list(self.stride())
  6546. stride[dim0], stride[dim1] = stride[dim1], stride[dim0]
  6547. size[dim0], size[dim1] = size[dim1], size[dim0]
  6548. self.as_strided_(size, stride)
  6549. return self
  6550. @register_meta(torch.ops.aten.t_)
  6551. def t_(self):
  6552. ndims = self.ndim
  6553. if self.is_sparse:
  6554. sparse_dim = self.sparse_dim()
  6555. dense_dim = self.dense_dim()
  6556. assert sparse_dim <= 2 and dense_dim == 0, (
  6557. f"t_ expects a tensor with <= 2 sparse and 0 dense dimensions, "
  6558. f"but got {sparse_dim} sparse and {dense_dim} dense dimensions"
  6559. )
  6560. else:
  6561. assert self.dim() <= 2, (
  6562. f"t_ expects a tensor with <= 2 dimensions, but self is {ndims}D"
  6563. )
  6564. return transpose_(self, 0, 0 if ndims < 2 else 1)
  6565. @register_meta(aten.searchsorted)
  6566. @out_wrapper()
  6567. def meta_searchsorted(
  6568. sorted_sequence,
  6569. self,
  6570. *,
  6571. out_int32=False,
  6572. right=False,
  6573. side=None,
  6574. sorter=None,
  6575. ):
  6576. # If the sorted_sequence is not one-dimensional, its shape must match that of values
  6577. # in all but the last dimension.
  6578. torch._check(
  6579. len(sorted_sequence.shape) <= 1
  6580. or sorted_sequence.shape[:-1] == self.shape[:-1],
  6581. lambda: (
  6582. "torch.searchsorted(): boundaries tensor should be 1 dimension or the "
  6583. "first N-1 dimensions of boundaries tensor and input value tensor must "
  6584. f"match, but we got boundaries tensor {list(sorted_sequence.shape)} and "
  6585. f"input value tensor {list(self.shape)}"
  6586. ),
  6587. )
  6588. # If a sorter array is provided, its dimensions must exactly match sorted_sequence.
  6589. torch._check(
  6590. sorter is None or sorted_sequence.shape == sorter.shape,
  6591. lambda: (
  6592. "torch.searchsorted(): boundary and sorter must have the same size, but "
  6593. f"got boundary tensor {list(sorted_sequence.shape)} and got sorter tensor "
  6594. f"{list(sorter.shape) if sorter is not None else []}"
  6595. ),
  6596. )
  6597. # Per the docs, if side == "left" and right is True, we error.
  6598. torch._check(
  6599. side != "left" or not right,
  6600. lambda: "torch.searchsorted(): side and right can't be set to opposites, got side of "
  6601. "left while right was True",
  6602. )
  6603. dtype = torch.int32 if out_int32 else torch.int64
  6604. if isinstance(self, torch.Tensor):
  6605. return torch.empty_like(
  6606. self, dtype=dtype, memory_format=torch.contiguous_format
  6607. )
  6608. else: # Scalar
  6609. return torch.empty((), dtype=dtype, device=sorted_sequence.device)
  6610. def _check_for_unsupported_isin_dtype(dtype):
  6611. torch._check(
  6612. dtype not in (torch.bool, torch.complex128, torch.complex64),
  6613. lambda: f"Unsupported input type encountered for isin(): {dtype}",
  6614. )
  6615. @register_meta(aten.embedding_dense_backward)
  6616. def meta_embedding_dense_backward(
  6617. grad_output,
  6618. indices,
  6619. num_weights,
  6620. padding_idx,
  6621. scale_grad_by_freq,
  6622. ):
  6623. grad_weight = grad_output.new_empty((num_weights, grad_output.size(-1)))
  6624. return grad_weight
  6625. @register_meta(aten._embedding_bag_backward)
  6626. def meta_embedding_bag_backward(
  6627. grad,
  6628. indices,
  6629. offsets,
  6630. offset2bag,
  6631. bag_size,
  6632. maximum_indices,
  6633. num_weights,
  6634. scale_grad_by_freq,
  6635. mode,
  6636. sparse,
  6637. per_sample_weights,
  6638. padding_idx=-1,
  6639. ):
  6640. if sparse:
  6641. return aten._embedding_bag_sparse_backward(
  6642. grad,
  6643. indices,
  6644. offsets,
  6645. offset2bag,
  6646. bag_size,
  6647. num_weights,
  6648. scale_grad_by_freq,
  6649. mode,
  6650. per_sample_weights,
  6651. padding_idx,
  6652. )
  6653. else:
  6654. return meta_embedding_bag_dense_backward(
  6655. grad,
  6656. indices,
  6657. offset2bag,
  6658. bag_size,
  6659. maximum_indices,
  6660. num_weights,
  6661. scale_grad_by_freq,
  6662. mode,
  6663. per_sample_weights,
  6664. padding_idx,
  6665. )
  6666. @register_meta(aten._embedding_bag_dense_backward)
  6667. def meta_embedding_bag_dense_backward(
  6668. grad,
  6669. indices,
  6670. offset2bag,
  6671. bag_size,
  6672. maximum_indices,
  6673. num_weights,
  6674. scale_grad_by_freq,
  6675. mode,
  6676. per_sample_weights,
  6677. padding_idx=-1,
  6678. ):
  6679. torch._check(
  6680. grad.dtype in [torch.float16, torch.bfloat16, torch.float32, torch.float64],
  6681. lambda: f"Unsupported input type encountered: {grad.dtype}",
  6682. )
  6683. if mode == MODE_MAX:
  6684. torch._check(maximum_indices is not None)
  6685. index_grad_weight = grad.new_empty((num_weights, grad.size(1)))
  6686. return index_grad_weight
  6687. @register_meta(aten._embedding_bag_per_sample_weights_backward)
  6688. def meta_embedding_bag_per_sample_weights_backward(
  6689. grad,
  6690. weight,
  6691. indices,
  6692. offsets,
  6693. offset2bag,
  6694. mode,
  6695. padding_idx=-1,
  6696. ):
  6697. embedding_features = grad.size(1)
  6698. torch._check(
  6699. mode == MODE_SUM,
  6700. lambda: "embedding_bag_backward: per_sample_weights only supported for mode='sum'",
  6701. )
  6702. torch._check(grad.dim() == 2)
  6703. torch._check(indices.dim() == 1)
  6704. num_samples = indices.size(0)
  6705. torch._check(weight.dim() == 2)
  6706. torch._check(weight.size(1) == embedding_features)
  6707. output = grad.new_empty((num_samples,))
  6708. return output
  6709. @register_meta(aten.isin)
  6710. @out_wrapper()
  6711. def meta_isin(elements, test_elements, *, assume_unique=False, invert=False):
  6712. torch._check(
  6713. isinstance(elements, Tensor) or isinstance(test_elements, Tensor),
  6714. lambda: "At least one of elements and test_elements must be a Tensor.",
  6715. )
  6716. if not isinstance(elements, Tensor):
  6717. elements = torch.tensor(elements, device=test_elements.device)
  6718. if not isinstance(test_elements, Tensor):
  6719. test_elements = torch.tensor(test_elements, device=elements.device)
  6720. _check_for_unsupported_isin_dtype(elements.dtype)
  6721. _check_for_unsupported_isin_dtype(test_elements.dtype)
  6722. return torch.empty_like(elements, dtype=torch.bool)
  6723. @register_meta(aten.polygamma)
  6724. @out_wrapper()
  6725. def meta_polygamma(n: int, self: Tensor) -> Tensor:
  6726. torch._check(n >= 0, lambda: "polygamma(n, x) does not support negative n.")
  6727. _, result_dtype = elementwise_dtypes(
  6728. self,
  6729. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  6730. )
  6731. return torch.empty_like(self, dtype=result_dtype)
  6732. @register_meta(aten._local_scalar_dense)
  6733. def meta_local_scalar_dense(self: Tensor):
  6734. raise RuntimeError("Tensor.item() cannot be called on meta tensors")
  6735. @register_meta(aten.silu)
  6736. @out_wrapper(exact_dtype=True)
  6737. def silu(self: Tensor) -> Tensor:
  6738. return torch.empty_like(self)
  6739. @register_meta(aten.sigmoid)
  6740. @out_wrapper()
  6741. def sigmoid(self: Tensor) -> Tensor:
  6742. _, result_dtype = elementwise_dtypes(
  6743. self,
  6744. type_promotion_kind=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT,
  6745. )
  6746. return torch.empty_like(self, dtype=result_dtype)
  6747. def _create_grouped_mm_output_tensor(mat1, mat2, offs, out_dtype):
  6748. mat1_is_2d = mat1.dim() == 2
  6749. mat2_is_2d = mat2.dim() == 2
  6750. if mat1_is_2d:
  6751. if mat2_is_2d:
  6752. out_size = [offs.size(0), mat1.size(0), mat2.size(1)]
  6753. else:
  6754. torch._check(
  6755. offs.size(0) == mat2.size(0), lambda: "matrix batch sizes have to match"
  6756. )
  6757. out_size = [mat1.size(0), mat2.size(-1)]
  6758. else:
  6759. if mat2_is_2d:
  6760. torch._check(
  6761. offs.size(0) == mat1.size(0), lambda: "matrix batch sizes have to match"
  6762. )
  6763. out_size = [mat1.size(1), mat2.size(1)]
  6764. else:
  6765. # regular bmm
  6766. torch._check(
  6767. mat1.size(0) == mat2.size(0), lambda: "batched dimension has to match"
  6768. )
  6769. out_size = [mat1.size(0), mat1.size(1), mat2.size(-1)]
  6770. out_dtype = out_dtype or mat1.dtype
  6771. if torch.version.cuda:
  6772. alignment = 16 // out_dtype.itemsize
  6773. size_padded = (out_size[-1] + alignment - 1) // alignment * alignment
  6774. if mat1_is_2d == mat2_is_2d:
  6775. out_stride = [out_size[1] * size_padded, size_padded, 1]
  6776. else:
  6777. out_stride = [size_padded, 1]
  6778. out = torch.empty_strided(
  6779. out_size, out_stride, dtype=out_dtype, device=mat1.device
  6780. )
  6781. else:
  6782. out = torch.empty(out_size, dtype=out_dtype, device=mat1.device)
  6783. return out
  6784. def _meta_grouped_mm_common(
  6785. mat_a: Tensor,
  6786. mat_b: Tensor,
  6787. scale_a: torch.Tensor | None,
  6788. scale_b: torch.Tensor | None,
  6789. offs: Tensor | None = None,
  6790. bias: Tensor | None = None,
  6791. scale_result: torch.Tensor | None = None,
  6792. out_dtype: torch.dtype | None = None,
  6793. use_fast_accum: bool = False,
  6794. ):
  6795. torch._check(
  6796. (scale_a is None) == (scale_b is None),
  6797. lambda: "Either both scale factors are given, or none",
  6798. )
  6799. scaled = scale_a is not None and scale_b is not None
  6800. # Implementing all the checks from
  6801. # _grouped_mm_cuda()/_scaled_grouped_mm_cuda() code in
  6802. # aten/src/ATen/native/cuda/Blas.cpp.
  6803. if scaled:
  6804. fp8_dtype = torch.float8_e4m3fnuz if torch.version.hip else torch.float8_e4m3fn
  6805. torch._check(
  6806. mat_a.dtype == fp8_dtype and mat_b.dtype == fp8_dtype,
  6807. lambda: f"Expected inputs of E4M3 FP8 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950
  6808. )
  6809. else:
  6810. torch._check(
  6811. mat_a.dtype == torch.bfloat16 and mat_b.dtype == torch.bfloat16,
  6812. lambda: f"Expected inputs of BF16 type but got mat_a.dtype={mat_a.dtype} and mat_b.dtype={mat_b.dtype}.", # noqa: B950
  6813. )
  6814. torch._check(
  6815. mat_a.dim() in [2, 3] and mat_b.dim() in [2, 3],
  6816. lambda: f"Multiplicands must be 2D or 3D but got mat_a.dim()={mat_a.dim()} and mat_b.dim()={mat_b.dim()}", # noqa: B950
  6817. )
  6818. mat_a_is_2d = mat_a.dim() == 2
  6819. mat_b_is_2d = mat_b.dim() == 2
  6820. if not mat_a_is_2d or not mat_b_is_2d:
  6821. torch._check(
  6822. mat_a.size(-1) == mat_b.size(-2),
  6823. lambda: "contraction dimension of mat_a and mat_b must match",
  6824. )
  6825. if scaled:
  6826. def is_row_major(mat):
  6827. mat_stride = mat.stride()
  6828. return mat_stride[-2] > 1 and mat_stride[-1] == 1
  6829. def is_col_major(mat):
  6830. mat_stride = mat.stride()
  6831. return mat_stride[-2] == 1 and mat_stride[-1] > 1
  6832. torch._check(
  6833. is_row_major(mat_a),
  6834. lambda: f"Expected mat_a tensor to be row major in the last two dimensions, got strides {mat_a.stride()[-2:]}", # noqa: B950
  6835. )
  6836. torch._check(
  6837. is_col_major(mat_b),
  6838. lambda: f"Expected mat_b tensor to be column major in the last two dimensions, got strides {mat_b.stride()[-2:]}", # noqa: B950
  6839. )
  6840. def check_valid_strides(mat_name, mat):
  6841. end_dim = mat.dim() - 1
  6842. alignment = 16 // mat.element_size()
  6843. mat_stride = mat.stride()
  6844. if mat_stride[end_dim - 1] == 1 and mat_stride[end_dim] >= max(
  6845. 1, mat.shape[end_dim - 1]
  6846. ):
  6847. torch._check(
  6848. mat_stride[end_dim] % alignment == 0,
  6849. lambda: f"Expected {mat_name} stride along {end_dim} dim to be multiple of 16 bytes, got {mat_stride[end_dim]}.", # noqa: B950
  6850. )
  6851. elif mat_stride[end_dim] == 1 and mat_stride[end_dim - 1] >= max(
  6852. 1, mat.shape[end_dim]
  6853. ):
  6854. torch._check(
  6855. mat_stride[end_dim - 1] % alignment == 0,
  6856. lambda: f"Expected {mat_name} stride along {end_dim - 1} dim to be multiple of 16 bytes, got {mat_stride[end_dim - 1]}.", # noqa: B950
  6857. )
  6858. else:
  6859. torch._check(
  6860. False,
  6861. lambda: f"Invalid strides/sizes, got {mat_stride} for strides and {mat.shape} for sizes.", # noqa: B950
  6862. )
  6863. check_valid_strides("mat_a", mat_a)
  6864. check_valid_strides("mat_b", mat_b)
  6865. if scale_a is not None and scale_b is not None:
  6866. torch._check(
  6867. (scale_a.dtype == torch.float32 and scale_b.dtype == torch.float32)
  6868. or (
  6869. scale_a.dtype == torch.float8_e8m0fnu
  6870. and scale_b.dtype == torch.float8_e8m0fnu
  6871. ),
  6872. lambda: f"For FP8 scales must both be float32, or for MXFP8 both scales must be float8_e8m0fnu. Got scale_a.dtype={scale_a.dtype} and scale_b.dtype={scale_b.dtype}.", # noqa: B950
  6873. )
  6874. is_mxfp8 = (
  6875. scale_a.dtype == torch.float8_e8m0fnu
  6876. and scale_b.dtype == torch.float8_e8m0fnu
  6877. )
  6878. def check_scale(scale_name, scale, mat, scaled_dim, scale_multiplier=1):
  6879. if mat.dim() == 2:
  6880. torch._check(
  6881. scale.is_contiguous(),
  6882. lambda: f"Expected {scale_name} to be contiguous.",
  6883. )
  6884. # For MXFP8, 2d tensors have variable size groups represented as subtensors,
  6885. # that are converted to blocked padded format individually. At compile time we don't know
  6886. # the group sizes yet, so we don't know the expect size of the blocked format scale.
  6887. # This limits what we can check here.
  6888. if is_mxfp8:
  6889. torch._check(
  6890. scale.dim() == mat.dim(),
  6891. lambda: f"For MXFP8, scale must have same number of dimensions as target tensor, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950
  6892. )
  6893. else:
  6894. torch._check(
  6895. scale.dim() == 1,
  6896. lambda: f"Expected {scale_name} to be 1D tensor, but got {scale.dim()}D tensor.",
  6897. )
  6898. torch._check(
  6899. scale.shape[0] == mat.shape[scaled_dim] * scale_multiplier,
  6900. lambda: f"Expected {scale_name} to have {mat.shape[scaled_dim] * scale_multiplier} elements, got {scale.shape[0]} elements.", # noqa: B950
  6901. )
  6902. else:
  6903. torch._check(
  6904. scale.stride(-1) == 1,
  6905. lambda: f"Expected {scale_name} to be contiguous in the last dimension.",
  6906. )
  6907. torch._check(
  6908. scale.shape[0] == mat.shape[0],
  6909. lambda: f"Expected {scale_name} batch dimension to be {mat.shape[0]}, got {scale.shape[0]}.",
  6910. )
  6911. # For MXFP8, 3d tensors have static 'groups' (stack of 2d tensors) so we can know the expected blocked
  6912. # scale sizes at compile time.
  6913. if is_mxfp8:
  6914. torch._check(
  6915. scale.ndim == mat.ndim - 1,
  6916. lambda: f"For MXFP8, 3d tensor should have 2d scales, but {scale_name} has mat.ndim={mat.ndim} and scale.ndim={scale.ndim}", # noqa: B950
  6917. )
  6918. # TODO: This logic only holds for RHS tensor in 2d-3d case.
  6919. # We'll need to update it to handle LHS 3d tensor in 3d-2d and 3d-3d cases.
  6920. G, K, N = mat.shape
  6921. block_size = 32
  6922. blocked_K = round_up(K / block_size, 4)
  6923. blocked_N = round_up(N, 128)
  6924. torch._check(
  6925. scale.shape[0] == G and scale.shape[1] == blocked_K * blocked_N,
  6926. lambda: f"For MXFP8, expected mat.shape={mat.shape} to have scale shape of ({G},{blocked_K * blocked_N}), but got {scale.shape}", # noqa: B950
  6927. )
  6928. else:
  6929. torch._check(
  6930. scale.dim() == 2,
  6931. lambda: f"Expected {scale_name} to be 2D tensor, but got {scale.dim()}D tensor.",
  6932. )
  6933. torch._check(
  6934. scale.shape[1] == mat.shape[1 + scaled_dim],
  6935. lambda: f"Expected {scale_name} non-batch dimension to be {mat.shape[1 + scaled_dim]}, got {scale.shape[1]}.", # noqa: B950
  6936. )
  6937. scale_multiplier = (
  6938. offs.shape[0] if offs is not None and mat_a_is_2d and mat_b_is_2d else 1
  6939. )
  6940. check_scale("scale_a", scale_a, mat_a, 0, scale_multiplier)
  6941. check_scale("scale_b", scale_b, mat_b, 1, scale_multiplier)
  6942. torch._check(
  6943. scale_result is None,
  6944. lambda: "Scale result tensor provided, but it is not supported yet.",
  6945. )
  6946. if mat_a_is_2d or mat_b_is_2d:
  6947. torch._check(
  6948. offs is not None,
  6949. lambda: f"Offsets tensor not provided, but is needed for {mat_a.dim()}D/{mat_b.dim()}D multiplicand layouts.",
  6950. )
  6951. if offs is not None: # to silence Mypy
  6952. torch._check(
  6953. offs.dim() == 1,
  6954. lambda: f"Offsets tensor must be 1D, but got offs.dim()={offs.dim()}.",
  6955. )
  6956. torch._check(
  6957. offs.dtype == torch.int32,
  6958. lambda: f"Offsets tensor must be integer (int32) tensor, but got {offs.dtype}.",
  6959. )
  6960. else:
  6961. torch._check(
  6962. offs is None,
  6963. lambda: "Offsets tensor provided, but is not needed for 3D/3D multiplicand layouts.",
  6964. )
  6965. torch._check(
  6966. bias is None,
  6967. lambda: "Bias tensor provided, but it is not supported yet.",
  6968. )
  6969. torch._check(
  6970. out_dtype is None or out_dtype == torch.bfloat16,
  6971. lambda: "If output dtype provided, it must be torch.bfloat16.",
  6972. )
  6973. return _create_grouped_mm_output_tensor(mat_a, mat_b, offs, out_dtype)
  6974. @register_meta(aten._grouped_mm)
  6975. @out_wrapper()
  6976. def meta_grouped_mm(
  6977. mat_a: Tensor,
  6978. mat_b: Tensor,
  6979. offs: Tensor | None = None,
  6980. bias: Tensor | None = None,
  6981. out_dtype: torch.dtype | None = None,
  6982. ) -> Tensor:
  6983. return _meta_grouped_mm_common(
  6984. mat_a,
  6985. mat_b,
  6986. scale_a=None,
  6987. scale_b=None,
  6988. offs=offs,
  6989. bias=bias,
  6990. scale_result=None,
  6991. out_dtype=out_dtype,
  6992. )
  6993. @register_meta([aten._scaled_grouped_mm])
  6994. def meta_scaled_grouped_mm(
  6995. mat_a: torch.Tensor,
  6996. mat_b: torch.Tensor,
  6997. scale_a: torch.Tensor,
  6998. scale_b: torch.Tensor,
  6999. offs: torch.Tensor | None = None,
  7000. bias: torch.Tensor | None = None,
  7001. scale_result: torch.Tensor | None = None,
  7002. out_dtype: torch.dtype | None = None,
  7003. use_fast_accum: bool = False,
  7004. ):
  7005. # matching _scaled_grouped_mm_cuda Blas.cpp implementation
  7006. out_dtype = out_dtype or torch.bfloat16
  7007. return _meta_grouped_mm_common(
  7008. mat_a,
  7009. mat_b,
  7010. scale_a=scale_a,
  7011. scale_b=scale_b,
  7012. offs=offs,
  7013. bias=bias,
  7014. scale_result=scale_result,
  7015. out_dtype=out_dtype,
  7016. use_fast_accum=use_fast_accum,
  7017. )
  7018. @register_meta(aten._softmax)
  7019. @out_wrapper()
  7020. def softmax(x: Tensor, dim: int, half_to_float: bool) -> Tensor:
  7021. if half_to_float:
  7022. assert x.dtype in [torch.half, torch.bfloat16]
  7023. computation_dtype, result_dtype = utils.elementwise_dtypes(
  7024. x, type_promotion_kind=utils.ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  7025. )
  7026. result_dtype = result_dtype if not half_to_float else computation_dtype
  7027. res = torch.empty_like(x, dtype=result_dtype, memory_format=torch.contiguous_format)
  7028. return res
  7029. @register_meta(aten.constant_pad_nd)
  7030. @out_wrapper()
  7031. def _constant_pad_nd_meta(input, pad, value=0):
  7032. # same checks as decomposition in torch/_refs/__init__.py:constant_pad_nd()
  7033. torch._check(
  7034. len(pad) % 2 == 0,
  7035. lambda: f"Length of pad must be even but instead it equals {len(pad)}",
  7036. )
  7037. input_sizes = input.shape
  7038. l_inp = len(input_sizes)
  7039. l_pad = len(pad) // 2
  7040. l_diff = l_inp - l_pad
  7041. torch._check(
  7042. l_inp >= l_pad,
  7043. lambda: "Length of pad should be no more than twice the number of "
  7044. f"dimensions of the input. Pad length is {len(pad)} while the input has "
  7045. f"{l_inp} dimensions.",
  7046. )
  7047. if all(isinstance(p, utils.IntWithoutSymInt) and p <= 0 for p in pad):
  7048. c_input = input
  7049. for i in range(l_diff, l_inp):
  7050. pad_idx = 2 * (l_inp - i - 1)
  7051. if pad[pad_idx] < 0:
  7052. c_input = c_input.narrow(
  7053. i, -pad[pad_idx], c_input.shape[i] + pad[pad_idx]
  7054. )
  7055. if pad[pad_idx + 1] < 0:
  7056. c_input = c_input.narrow(i, 0, c_input.shape[i] + pad[pad_idx + 1])
  7057. return c_input.clone()
  7058. new_shape = list(input_sizes[:l_diff])
  7059. for i in range(l_pad):
  7060. pad_idx = len(pad) - ((i + 1) * 2)
  7061. new_dim = input_sizes[l_diff + i] + pad[pad_idx] + pad[pad_idx + 1]
  7062. torch._check(
  7063. new_dim >= 0,
  7064. lambda: f"The input size {input_sizes[l_diff + i]}, plus negative padding "
  7065. f"{pad[pad_idx]} and {pad[pad_idx + 1]} resulted in a negative output size, "
  7066. f"which is invalid. Check dimension {l_diff + i} of your input.",
  7067. )
  7068. new_shape.append(new_dim)
  7069. return torch.empty(
  7070. new_shape,
  7071. dtype=input.dtype,
  7072. device=input.device,
  7073. requires_grad=input.requires_grad,
  7074. memory_format=suggest_memory_format(input),
  7075. )
  7076. @register_meta(aten.embedding)
  7077. @out_wrapper()
  7078. def embedding(
  7079. weight: Tensor,
  7080. indices: Tensor,
  7081. padding_idx: int = -1,
  7082. scale_grad_by_freq: bool = False,
  7083. sparse: bool = False,
  7084. ) -> Tensor:
  7085. assert weight.dim() == 2, "'weight' must be 2-D"
  7086. weight_shape = weight.shape
  7087. indices_shape = indices.shape
  7088. if indices.ndim == 0:
  7089. out_shape: tuple[int, ...] = (weight_shape[1],)
  7090. elif indices.ndim == 1:
  7091. out_shape = (indices_shape[0], weight_shape[1])
  7092. else:
  7093. out_shape = (*indices_shape, weight_shape[1])
  7094. out_dtype = weight.dtype
  7095. return weight.new_empty(out_shape, dtype=out_dtype)
  7096. @register_meta(aten._jagged_to_padded_dense_forward.default)
  7097. def meta__jagged_to_padded_dense_forward(
  7098. values: Tensor,
  7099. offsets: list[Tensor],
  7100. max_lengths: list[int],
  7101. padding_value: float = 0.0,
  7102. ):
  7103. # only one jagged dim is supported for now
  7104. assert len(offsets) == 1
  7105. assert len(max_lengths) == 1
  7106. B = offsets[0].shape[0] - 1
  7107. S = max_lengths[0]
  7108. output_shape = (B, S, *values.shape[1:])
  7109. return values.new_empty(output_shape)
  7110. def _create_unary_float_meta_func(func):
  7111. @register_meta(func)
  7112. @out_wrapper()
  7113. def _f(x):
  7114. return elementwise_meta(
  7115. x, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
  7116. )
  7117. return _f
  7118. # Implementation follows cuda implementation native_multi_head_attention_cuda
  7119. @register_meta(aten._native_multi_head_attention.default)
  7120. def native_multi_head_attention_fake(
  7121. query,
  7122. key,
  7123. value,
  7124. embed_dim,
  7125. num_head,
  7126. qkv_weight,
  7127. qkv_bias,
  7128. proj_weight,
  7129. proj_bias,
  7130. mask=None,
  7131. need_weights=True,
  7132. average_attn_weights=True,
  7133. mask_type=None,
  7134. ):
  7135. if query.is_nested or key.is_nested or value.is_nested:
  7136. raise NotImplementedError(
  7137. "_native_multi_head_attention fake implementation does not support nested tensors"
  7138. )
  7139. if query.numel() == 0:
  7140. return (query.new_empty(query.shape), query.new_empty(0))
  7141. B = query.size(0) # B: batch size
  7142. T = query.size(1) # T: target sequence length
  7143. # In native_multi_head_attention_cuda,
  7144. # we have proj = transform0213_gemm_nt_bias(attn_ctx, proj_weight, proj_bias, query)
  7145. # , which does attn_ctx @ proj_weight.T + proj_bias
  7146. # so the last dim of output shape is proj_weight.size(0)
  7147. output_dim = proj_weight.size(0)
  7148. output = query.new_empty(B, T, output_dim)
  7149. if need_weights:
  7150. if average_attn_weights:
  7151. # When averaging attention weights, shape is [B, T, T] (averaged over heads)
  7152. # T = query seq len, S = key/value seq len
  7153. attn_weights = query.new_empty(B, T, T)
  7154. else:
  7155. # When not averaging, shape is [B, num_head, T, T]
  7156. # T = query seq len, S = key/value seq len
  7157. attn_weights = query.new_empty(B, num_head, T, T)
  7158. else:
  7159. attn_weights = query.new_empty(0)
  7160. return (output, attn_weights)
  7161. def _create_binary_float_meta_func(func):
  7162. @register_meta(func)
  7163. @out_wrapper()
  7164. def _f(x, y):
  7165. return elementwise_meta(
  7166. x, y, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.INT_TO_FLOAT
  7167. )
  7168. return _f
  7169. _create_unary_float_meta_func(aten.special_airy_ai)
  7170. _create_unary_float_meta_func(aten.special_bessel_y0)
  7171. _create_unary_float_meta_func(aten.special_bessel_y1)
  7172. _create_unary_float_meta_func(aten.special_modified_bessel_i0)
  7173. _create_unary_float_meta_func(aten.special_modified_bessel_i1)
  7174. _create_unary_float_meta_func(aten.special_modified_bessel_k0)
  7175. _create_unary_float_meta_func(aten.special_modified_bessel_k1)
  7176. _create_unary_float_meta_func(aten.special_scaled_modified_bessel_k0)
  7177. _create_unary_float_meta_func(aten.special_scaled_modified_bessel_k1)
  7178. _create_binary_float_meta_func(aten.special_chebyshev_polynomial_t)
  7179. _create_binary_float_meta_func(aten.special_chebyshev_polynomial_u)
  7180. _create_binary_float_meta_func(aten.special_chebyshev_polynomial_v)
  7181. _create_binary_float_meta_func(aten.special_chebyshev_polynomial_w)
  7182. _create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_t)
  7183. _create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_u)
  7184. _create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_v)
  7185. _create_binary_float_meta_func(aten.special_shifted_chebyshev_polynomial_w)
  7186. _create_binary_float_meta_func(aten.special_hermite_polynomial_h)
  7187. _create_binary_float_meta_func(aten.special_hermite_polynomial_he)
  7188. _create_binary_float_meta_func(aten.special_laguerre_polynomial_l)
  7189. _create_binary_float_meta_func(aten.special_legendre_polynomial_p)
  7190. def _register_inplace_meta(fn):
  7191. @wraps(fn)
  7192. def _fn(self, *args, **kwargs):
  7193. out = fn(self, *args, **kwargs)
  7194. check_inplace_broadcast(self.shape, out.shape)
  7195. return self
  7196. inplace_name = f"{fn.__name__}_"
  7197. _fn.__name__ = inplace_name
  7198. _fn = register_meta(getattr(aten, inplace_name))(_fn) # type: ignore[assignment]
  7199. return _fn
  7200. @register_meta(aten.lerp)
  7201. @out_wrapper()
  7202. def lerp(start, end, weight):
  7203. torch._check(
  7204. start.dtype == end.dtype,
  7205. lambda: f"expected dtype {start.dtype} for `end`, but got dtype {end.dtype}",
  7206. )
  7207. args = [start, end]
  7208. if isinstance(weight, TensorLike):
  7209. if weight.ndim != 0:
  7210. torch._check(
  7211. start.dtype == weight.dtype,
  7212. lambda: f"expected dtype {start.dtype} for `weight`, but got dtype {weight.dtype}",
  7213. )
  7214. args.append(weight)
  7215. return elementwise_meta(
  7216. *args, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  7217. )
  7218. @register_meta(aten.addcmul)
  7219. @out_wrapper()
  7220. def addcmul(input, tensor1, tensor2, *, value=1):
  7221. return elementwise_meta(
  7222. input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  7223. )
  7224. @register_meta(aten.addcdiv)
  7225. @out_wrapper()
  7226. def addcdiv(input, tensor1, tensor2, *, value=1):
  7227. torch._check(
  7228. not (
  7229. utils.is_integer_dtype(tensor1.dtype)
  7230. and utils.is_integer_dtype(tensor2.dtype)
  7231. ),
  7232. lambda: (
  7233. "Integer division with addcdiv is no longer supported, and in a future ",
  7234. "release addcdiv will perform a true division of tensor1 and tensor2. ",
  7235. "The historic addcdiv behavior can be implemented as ",
  7236. "(input + value * torch.trunc(tensor1 / tensor2)).to(input.dtype) ",
  7237. "for integer inputs and as ",
  7238. "(input + value * tensor1 / tensor2) for float inputs. ",
  7239. "The future addcdiv behavior is just the latter implementation: ",
  7240. "(input + value * tensor1 / tensor2), for all dtypes.",
  7241. ),
  7242. )
  7243. return elementwise_meta(
  7244. input, tensor1, tensor2, type_promotion=ELEMENTWISE_TYPE_PROMOTION_KIND.DEFAULT
  7245. )
  7246. lerp_ = _register_inplace_meta(aten.lerp)
  7247. addcmul_ = _register_inplace_meta(aten.addcmul)
  7248. addcdiv_ = _register_inplace_meta(aten.addcdiv)
  7249. # We must also trigger meta registrations from PrimTorch ref
  7250. # decompositions
  7251. import torch._refs
  7252. import torch._refs.nn.functional
  7253. import torch._refs.special
  7254. def activate_meta():
  7255. activate_meta_table = {}
  7256. # For a given op, we pick the most specific decomp function from
  7257. # global_decomp_table in the precedence order of meta > post_autograd > pre_autograd
  7258. for type in ["meta", "post_autograd", "pre_autograd"]:
  7259. registry = global_decomposition_table[type]
  7260. for opo in registry:
  7261. if opo not in activate_meta_table:
  7262. activate_meta_table[opo] = registry[opo]
  7263. for op_overload, fn in activate_meta_table.items():
  7264. # Don't register meta for HigherOrderOp's decomp.
  7265. # We can reconsider this in the future, but in general,
  7266. # the way you do a meta for a HigherOrderOp is different from
  7267. # OpOverload.
  7268. if isinstance(op_overload, torch._ops.HigherOrderOperator):
  7269. continue
  7270. assert isinstance(op_overload, OpOverload)
  7271. op_overload.py_impl(torch._C.DispatchKey.Meta)(fn)
  7272. if torch._C._dispatch_has_kernel_for_dispatch_key(
  7273. op_overload.name(), "CompositeImplicitAutograd"
  7274. ):
  7275. # Internally, we shouldn't be registering meta kernels for any operators that
  7276. # have CompositeImplicitAutograd kernels.
  7277. # Instead, we should be letting those decompositions run, and writing meta kernels
  7278. # only for the base operators.
  7279. if op_overload in global_decomposition_table["meta"]:
  7280. raise RuntimeError(
  7281. f"{op_overload} is a CompositeImplicitAutograd op, we shouldn't "
  7282. "register meta function for it. Instead, we should let the decomposition run and write "
  7283. "meta kernels for the base operators."
  7284. )
  7285. elif op_overload.is_view:
  7286. # Attempting to register a python meta kernel for a view operator.
  7287. # We shouldn't do this, because the output will report as not having aliased storages.
  7288. # All view ops have meta kernels in C++ today, so we should use those instead.
  7289. pass
  7290. elif (
  7291. op_overload.name()
  7292. in {
  7293. "aten::empty_strided", # causing infinite recursion, test_meta.py
  7294. "aten::clone", # causing infinite recursion
  7295. "aten::_to_copy", # causing infinite recursion, test_serialization.py -k test_tensor_subclass_getstate_overwrite # noqa: B950
  7296. "aten::copy_", # Exception not raised, test_torch.py -k test_storage_meta_errors_cpu_int64 # noqa: B950
  7297. "aten::constant_pad_nd", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_amp_istft_cuda_float32 # noqa: B950
  7298. "aten::rot90", # requires_grad mismatch! test_ops.py -k test_fake_crossref_backward_amp_rot90_cuda_float32 # noqa: B950
  7299. "aten::as_strided_scatter", # requires_grad mismatch, test_ops.py -k test_fake_crossref_backward_no_amp_as_strided_scatter_cuda_float32 # noqa: B950
  7300. }
  7301. ):
  7302. pass
  7303. else:
  7304. if "mkldnn::" in op_overload.name():
  7305. _meta_lib_dont_use_me_use_register_meta_for_mkldnn.impl(op_overload, fn)
  7306. elif "mkl::" in op_overload.name():
  7307. _meta_lib_dont_use_me_use_register_meta_for_mkl.impl(op_overload, fn)
  7308. elif "onednn::" in op_overload.name():
  7309. _meta_lib_dont_use_me_use_register_meta_for_onednn.impl(op_overload, fn)
  7310. elif "quantized::" in op_overload.name():
  7311. _meta_lib_dont_use_me_use_register_meta_for_quantized.impl(
  7312. op_overload, fn
  7313. )
  7314. else:
  7315. _meta_lib_dont_use_me_use_register_meta.impl(op_overload, fn)
  7316. activate_meta()