ir.py 329 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165616661676168616961706171617261736174617561766177617861796180618161826183618461856186618761886189619061916192619361946195619661976198619962006201620262036204620562066207620862096210621162126213621462156216621762186219622062216222622362246225622662276228622962306231623262336234623562366237623862396240624162426243624462456246624762486249625062516252625362546255625662576258625962606261626262636264626562666267626862696270627162726273627462756276627
76278627962806281628262836284628562866287628862896290629162926293629462956296629762986299630063016302630363046305630663076308630963106311631263136314631563166317631863196320632163226323632463256326632763286329633063316332633363346335633663376338633963406341634263436344634563466347634863496350635163526353635463556356635763586359636063616362636363646365636663676368636963706371637263736374637563766377637863796380638163826383638463856386638763886389639063916392639363946395639663976398639964006401640264036404640564066407640864096410641164126413641464156416641764186419642064216422642364246425642664276428642964306431643264336434643564366437643864396440644164426443644464456446644764486449645064516452645364546455645664576458645964606461646264636464646564666467646864696470647164726473647464756476647764786479648064816482648364846485648664876488648964906491649264936494649564966497649864996500650165026503650465056506650765086509651065116512651365146515651665176518651965206521652265236524652565266527652865296530653165326533653465356536653765386539654065416542654365446545654665476548654965506551655265536554655565566557655865596560656165626563656465656566656765686569657065716572657365746575657665776578657965806581658265836584658565866587658865896590659165926593659465956596659765986599660066016602660366046605660666076608660966106611661266136614661566166617661866196620662166226623662466256626662766286629663066316632663366346635663666376638663966406641664266436644664566466647664866496650665166526653665466556656665766586659666066616662666366646665666666676668666966706671667266736674667566766677667866796680668166826683668466856686668766886689669066916692669366946695669666976698669967006701670267036704670567066707670867096710671167126713671467156716671767186719672067216722672367246725672667276728672967306731673267336734673567366737673867396740674167426743674467456746674767486749675067516752675367546755675667576758675967606761676267636764676567666767676867696770677167726773677467756776677
76778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727
77278727972807281728272837284728572867287728872897290729172927293729472957296729772987299730073017302730373047305730673077308730973107311731273137314731573167317731873197320732173227323732473257326732773287329733073317332733373347335733673377338733973407341734273437344734573467347734873497350735173527353735473557356735773587359736073617362736373647365736673677368736973707371737273737374737573767377737873797380738173827383738473857386738773887389739073917392739373947395739673977398739974007401740274037404740574067407740874097410741174127413741474157416741774187419742074217422742374247425742674277428742974307431743274337434743574367437743874397440744174427443744474457446744774487449745074517452745374547455745674577458745974607461746274637464746574667467746874697470747174727473747474757476747774787479748074817482748374847485748674877488748974907491749274937494749574967497749874997500750175027503750475057506750775087509751075117512751375147515751675177518751975207521752275237524752575267527752875297530753175327533753475357536753775387539754075417542754375447545754675477548754975507551755275537554755575567557755875597560756175627563756475657566756775687569757075717572757375747575757675777578757975807581758275837584758575867587758875897590759175927593759475957596759775987599760076017602760376047605760676077608760976107611761276137614761576167617761876197620762176227623762476257626762776287629763076317632763376347635763676377638763976407641764276437644764576467647764876497650765176527653765476557656765776587659766076617662766376647665766676677668766976707671767276737674767576767677767876797680768176827683768476857686768776887689769076917692769376947695769676977698769977007701770277037704770577067707770877097710771177127713771477157716771777187719772077217722772377247725772677277728772977307731773277337734773577367737773877397740774177427743774477457746774777487749775077517752775377547755775677577758775977607761776277637764776577667767776877697770777177727773777477757776777
77778777977807781778277837784778577867787778877897790779177927793779477957796779777987799780078017802780378047805780678077808780978107811781278137814781578167817781878197820782178227823782478257826782778287829783078317832783378347835783678377838783978407841784278437844784578467847784878497850785178527853785478557856785778587859786078617862786378647865786678677868786978707871787278737874787578767877787878797880788178827883788478857886788778887889789078917892789378947895789678977898789979007901790279037904790579067907790879097910791179127913791479157916791779187919792079217922792379247925792679277928792979307931793279337934793579367937793879397940794179427943794479457946794779487949795079517952795379547955795679577958795979607961796279637964796579667967796879697970797179727973797479757976797779787979798079817982798379847985798679877988798979907991799279937994799579967997799879998000800180028003800480058006800780088009801080118012801380148015801680178018801980208021802280238024802580268027802880298030803180328033803480358036803780388039804080418042804380448045804680478048804980508051805280538054805580568057805880598060806180628063806480658066806780688069807080718072807380748075807680778078807980808081808280838084808580868087808880898090809180928093809480958096809780988099810081018102810381048105810681078108810981108111811281138114811581168117811881198120812181228123812481258126812781288129813081318132813381348135813681378138813981408141814281438144814581468147814881498150815181528153815481558156815781588159816081618162816381648165816681678168816981708171817281738174817581768177817881798180818181828183818481858186818781888189819081918192819381948195819681978198819982008201820282038204820582068207820882098210821182128213821482158216821782188219822082218222822382248225822682278228822982308231823282338234823582368237823882398240824182428243824482458246824782488249825082518252825382548255825682578258825982608261826282638264826582668267826882698270827182728273827482758276827
78278827982808281828282838284828582868287828882898290829182928293829482958296829782988299830083018302830383048305830683078308830983108311831283138314831583168317831883198320832183228323832483258326832783288329833083318332833383348335833683378338833983408341834283438344834583468347834883498350835183528353835483558356835783588359836083618362836383648365836683678368836983708371837283738374837583768377837883798380838183828383838483858386838783888389839083918392839383948395839683978398839984008401840284038404840584068407840884098410841184128413841484158416841784188419842084218422842384248425842684278428842984308431843284338434843584368437843884398440844184428443844484458446844784488449845084518452845384548455845684578458845984608461846284638464846584668467846884698470847184728473847484758476847784788479848084818482848384848485848684878488848984908491849284938494849584968497849884998500850185028503850485058506850785088509851085118512851385148515851685178518851985208521852285238524852585268527852885298530853185328533853485358536853785388539854085418542854385448545854685478548854985508551855285538554855585568557855885598560856185628563856485658566856785688569857085718572857385748575857685778578857985808581858285838584858585868587858885898590859185928593859485958596859785988599860086018602860386048605860686078608860986108611861286138614861586168617861886198620862186228623862486258626862786288629863086318632863386348635863686378638863986408641864286438644864586468647864886498650865186528653865486558656865786588659866086618662866386648665866686678668866986708671867286738674867586768677867886798680868186828683868486858686868786888689869086918692869386948695869686978698869987008701870287038704870587068707870887098710871187128713871487158716871787188719872087218722872387248725872687278728872987308731873287338734873587368737873887398740874187428743874487458746874787488749875087518752875387548755875687578758875987608761876287638764876587668767876887698770877187728773877487758776877
78778877987808781878287838784878587868787878887898790879187928793879487958796879787988799880088018802880388048805880688078808880988108811881288138814881588168817881888198820882188228823882488258826882788288829883088318832883388348835883688378838883988408841884288438844884588468847884888498850885188528853885488558856885788588859886088618862886388648865886688678868886988708871887288738874887588768877887888798880888188828883888488858886888788888889889088918892889388948895889688978898889989008901890289038904890589068907890889098910891189128913891489158916891789188919892089218922892389248925892689278928892989308931893289338934893589368937893889398940894189428943894489458946894789488949895089518952895389548955895689578958895989608961896289638964896589668967896889698970897189728973897489758976897789788979898089818982898389848985898689878988898989908991899289938994899589968997899889999000900190029003900490059006900790089009901090119012901390149015901690179018901990209021902290239024902590269027902890299030903190329033903490359036903790389039904090419042904390449045904690479048904990509051905290539054905590569057905890599060906190629063906490659066906790689069907090719072907390749075907690779078907990809081908290839084908590869087908890899090909190929093909490959096909790989099910091019102910391049105910691079108910991109111911291139114911591169117911891199120912191229123912491259126912791289129913091319132913391349135913691379138913991409141914291439144914591469147914891499150915191529153915491559156915791589159916091619162916391649165916691679168916991709171917291739174917591769177917891799180918191829183918491859186918791889189919091919192919391949195919691979198919992009201920292039204920592069207920892099210921192129213921492159216921792189219922092219222922392249225922692279228922992309231923292339234923592369237923892399240924192429243924492459246924792489249925092519252925392549255925692579258925992609261926292639264926592669267926892699270927192729273927492759276927
7927892799280928192829283928492859286928792889289929092919292929392949295929692979298929993009301930293039304930593069307930893099310931193129313931493159316931793189319932093219322932393249325932693279328932993309331933293339334933593369337933893399340934193429343934493459346934793489349935093519352935393549355935693579358935993609361936293639364936593669367936893699370937193729373937493759376937793789379938093819382938393849385938693879388938993909391939293939394939593969397939893999400940194029403940494059406940794089409
  1. from __future__ import annotations
  2. import contextlib
  3. import dataclasses
  4. import functools
  5. import itertools
  6. import logging
  7. import operator
  8. import os
  9. import textwrap
  10. import traceback
  11. from collections.abc import Container, Generator, Iterable, Iterator, Sequence
  12. from contextlib import AbstractContextManager, nullcontext
  13. from enum import Enum
  14. from functools import partial
  15. from typing import (
  16. Any,
  17. Callable,
  18. cast,
  19. ClassVar,
  20. Literal,
  21. Optional,
  22. overload,
  23. SupportsFloat,
  24. SupportsInt,
  25. TYPE_CHECKING,
  26. TypeVar,
  27. Union,
  28. )
  29. from typing_extensions import (
  30. assert_never,
  31. Never,
  32. override,
  33. ParamSpec,
  34. Self,
  35. TypeAlias,
  36. TypeIs,
  37. )
  38. from unittest.mock import patch
  39. import sympy
  40. from sympy import Expr, Integer, Symbol
  41. import torch._export.serde.schema as export_schema
  42. import torch._library.utils as library_utils
  43. import torch._logging
  44. import torch.fx
  45. import torch.utils._pytree as pytree
  46. from torch._dynamo.utils import identity
  47. from torch._export.serde.serialize import GraphModuleSerializer
  48. from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
  49. from torch._inductor import metrics
  50. from torch._inductor.utils import get_free_symbols
  51. from torch._prims_common import (
  52. compute_required_storage_length,
  53. is_boolean_dtype,
  54. is_float_dtype,
  55. make_channels_last_strides_for,
  56. StrideType,
  57. )
  58. from torch._subclasses.fake_tensor import get_schema_info
  59. from torch.fx.experimental.symbolic_shapes import (
  60. _remove_effect_token_unbacked_bindings,
  61. compute_unbacked_bindings,
  62. free_symbols,
  63. free_unbacked_symbols,
  64. IterateExprs,
  65. rebind_unbacked,
  66. resolve_unbacked_bindings,
  67. ShapeEnv,
  68. SymTypes,
  69. )
  70. from torch.fx.node import Node
  71. from torch.utils._ordered_set import OrderedSet
  72. from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing
  73. from torch.utils._sympy.symbol import SymT
  74. from . import config, dependencies
  75. from .codegen.common import (
  76. BackendFeature,
  77. CodegenSymbol,
  78. get_scheduling_for_device,
  79. index_prevent_reordering,
  80. Kernel,
  81. )
  82. from .dependencies import (
  83. Dep,
  84. extract_free_symbols,
  85. extract_input_node_reduction_ranges,
  86. extract_read_writes,
  87. var_builder,
  88. )
  89. from .loop_body import LoopBody
  90. from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode
  91. from .runtime.benchmarking import benchmarker
  92. from .runtime.hints import DeviceProperties, ReductionHint
  93. from .utils import (
  94. argsort,
  95. argsort_sym,
  96. cache_on_self,
  97. cache_on_self_and_args,
  98. ceildiv,
  99. convert_shape_to_inductor,
  100. convert_shape_to_symint,
  101. developer_warning,
  102. do_bench_using_profiling,
  103. dtype_from_size,
  104. get_dtype_size,
  105. get_kernel_metadata,
  106. GPU_ALIGN_BYTES,
  107. ir_dataclass,
  108. is_dynamic,
  109. is_gpu,
  110. sympy_dot,
  111. sympy_index_symbol,
  112. sympy_index_symbol_with_prefix,
  113. sympy_product,
  114. sympy_subs,
  115. tensor_is_aligned,
  116. )
  117. from .virtualized import ops, OpsValue, V
  118. if TYPE_CHECKING:
  119. from torch._library.fake_class_registry import FakeScriptObject
  120. from torch.fx.experimental.symbolic_shapes import SympyBoolean
  121. from torch.fx.node import Argument
  122. from .codegen.cuda.cuda_template import CUDATemplate
  123. from .codegen.wrapper import PythonWrapperCodegen
  124. from .graph import GraphLowering
  125. from .utils import IndentedBuffer
  126. else:
  127. CUDATemplate: TypeAlias = object
# Triton is optional: record whether it could be imported and, if so, its version.
try:
    import triton
    triton_version = triton.__version__
    has_triton = True
except ImportError:
    triton_version = None
    has_triton = False
# Generic type parameters used in the signatures of helpers in this module.
_P = ParamSpec("_P")
_T = TypeVar("_T")
_U = TypeVar("_U")
_V = TypeVar("_V")
# Integer-like / number-like values: plain Python numbers or sympy expressions.
_IntLike: TypeAlias = Union[int, Expr]
_NumLike: TypeAlias = Union[int, float, Expr]
_OpOverloads: TypeAlias = Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
log = logging.getLogger(__name__)
# textwrap.indent with a fixed prefix; used when pretty-printing IR nodes.
indent = functools.partial(textwrap.indent, prefix=" ")
aten = torch.ops.aten
# Autotune settings, overridable through the environment (defaults: 25 and 100).
autotune_warmup = int(os.getenv("TORCH_AUTOTUNE_WARMUP", 25))
autotune_rep = int(os.getenv("TORCH_AUTOTUNE_REP", 100))
  147. """ [Note: Inductor IR]
  148. Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
  149. lowering is registered to a particular aten operator, and expects inputs that
  150. correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
  151. expect Inductor TensorBox inputs.
  152. TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
  153. storage, and sometimes views of another Tensor's storage. Mutating tensor operations
  154. (such as add_()) affect the underlying storage and any associated views. Other operations
  155. (such as .t_()) update metadata about the current view but don't modify the underlying storage.
  156. To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
  157. TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
  158. output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
  159. reference View IR or directly reference StorageBox IRs.
  160. Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
  161. may take an existing TensorBox and point it to a new underlying View IR.
  162. Tensors that directly own storage are represented as a chain of:
  163. TensorBox -> StorageBox -> Buffer
  164. where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
  165. If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
  166. (leaving the old buffer unmodified and functionalizing the operation).
  167. Tensors backed by views add one more indirection to the IR.
  168. TensorBox -> View -> StorageBox -> Buffer
  169. In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
Computation is represented by Operation nodes, with each operation producing 1
or more output Buffers. In the case of mutations, these will be new Buffers that have the
mutated buffer listed in their get_mutation_names().
  173. It is also possible to have an InputBuffer for which there is no corresponding Operation,
  174. e.g. it may be a graph input or compile time constant.
  175. """
# Shapes accepted by validate_ir(): a single value (int, TensorBox, Symbol,
# IRNode, or a str -> TensorBox dict) or a sequence of such values / None.
_NodeOrNodes: TypeAlias = Union[
    int,
    "TensorBox",
    dict[str, "TensorBox"],
    "Symbol",
    "IRNode",
    Sequence[
        Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
    ],
]
  186. def _is_static(x: object) -> bool:
  187. return isinstance(x, (int, Integer))
  188. @dataclasses.dataclass(frozen=True)
  189. class GraphPartitionSignature:
  190. # symbol inputs that are necessary for codegen
  191. symbol_inputs: OrderedSet[sympy.Symbol]
  192. # mapping from partition input name to IRNode or Expr. Need the name str since
  193. # we cannot get name from Expr.
  194. input_nodes: dict[str, Union[IRNode, sympy.Expr, TorchBindObject]]
  195. output_nodes: list[IRNode]
  196. # mapping from partition input name to a boolean for whether deallocating it
  197. # in the partition function
  198. input_deallocation: dict[str, bool]
  199. skip_cudagraph: bool
  200. # name of constants read/written by the graph partition
  201. constant_names: list[str]
  202. def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
  203. def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None:
  204. # Could expand this to check deeper properties
  205. # (e.g. TensorBox points to View or StorageBox)
  206. if nodes is None:
  207. pass
  208. elif isinstance(nodes, (list, tuple)):
  209. for node in nodes:
  210. _check_tensorbox(node)
  211. elif isinstance(nodes, dict):
  212. for node in nodes.values():
  213. _check_tensorbox(node)
  214. else:
  215. assert isinstance(
  216. nodes,
  217. (
  218. ExpandView,
  219. DynamicScalar,
  220. AssertScalar,
  221. TensorBox,
  222. sympy.logic.boolalg.Boolean,
  223. Expr,
  224. int,
  225. EffectfulKernel,
  226. ShapeAsConstantBuffer,
  227. ),
  228. ), (
  229. f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
  230. )
  231. # Be picky about the accepted data structure (don't use pytree here)
  232. _check_tensorbox(node_or_nodes)
  233. def ops_wrapper(name: str) -> Callable[..., OpsValue]:
  234. assert isinstance(name, str), type(name)
  235. def fn(*args: object, **kwargs: object) -> OpsValue:
  236. return getattr(ops, name)(*args, **kwargs)
  237. return fn
  238. def inverse_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]:
  239. inv_order = dict(zip(order, range(len(order))))
  240. def reindex(index: Sequence[_T]) -> Sequence[_T]:
  241. assert len(index) == len(inv_order)
  242. return [index[inv_order[i]] for i in range(len(index))]
  243. return reindex
  244. def same_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]:
  245. def reindex(index: Sequence[_T]) -> Sequence[_T]:
  246. assert len(index) == len(order)
  247. return [index[order[i]] for i in range(len(index))]
  248. return reindex
  249. def fuse_reindexing(
  250. reindex1: Callable[[Sequence[_U]], Sequence[_V]],
  251. reindex2: Callable[[Sequence[_T]], Sequence[_U]],
  252. ) -> Callable[[Sequence[_T]], Sequence[_V]]:
  253. def reindex(index: Sequence[_T]) -> Sequence[_V]:
  254. return reindex1(reindex2(index))
  255. return reindex
# Stride order of the 4-D channels-last format (see stride_order2fill_order).
NHWC_STRIDE_ORDER = [3, 0, 2, 1]
# NOTE(review): presumably the 5-D channels-last analogue - confirm against callers.
NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
  258. def get_fill_order(
  259. seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None
  260. ) -> Sequence[int]:
  261. """
  262. Convert strides to fill order (argsort)
  263. """
  264. if shape_env is None or all(isinstance(s, (int, sympy.Integer)) for s in seq):
  265. sorted_idx: Sequence[int] = argsort(seq)
  266. else:
  267. # argsort_sym handles unbacked symints (with the help of the shape_env)
  268. sorted_idx = argsort_sym(shape_env, seq)
  269. return sorted_idx
  270. def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[int]:
  271. """
  272. Convert stride order to fill order
  273. For channel last format,
  274. stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
  275. """
  276. lookup = {pos: idx for idx, pos in enumerate(order)}
  277. fill_order = [lookup[i] for i in range(len(order))]
  278. return fill_order
  279. def get_stride_order(
  280. seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None
  281. ) -> Sequence[int]:
  282. """
  283. Convert strides to stride order
  284. """
  285. sorted_idx: Sequence[int] = get_fill_order(seq, shape_env)
  286. out = [0 for _ in range(len(seq))]
  287. for i, elem in enumerate(sorted_idx):
  288. out[elem] = i
  289. return out
  290. @overload
  291. def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ...
  292. @overload
  293. def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ...
  294. def ir_node_to_tensor(
  295. x: Optional[IRNode], guard_shape: bool = True
  296. ) -> Optional[torch.Tensor]:
  297. if x is None:
  298. return None
  299. shape_fn: Callable[[Union[int, Expr]], Union[int, Expr]]
  300. if not guard_shape:
  301. shape_fn = V.graph.sizevars.size_hint
  302. else:
  303. shape_fn = identity
  304. size = [shape_fn(s) for s in x.get_size()]
  305. stride: StrideType
  306. if is_storage_and_layout(x):
  307. stride = [shape_fn(s) for s in x.get_layout().stride]
  308. else:
  309. stride = FlexibleLayout.contiguous_strides(size)
  310. dtype = x.get_dtype()
  311. device = x.get_device()
  312. size = convert_shape_to_symint(size)
  313. stride = convert_shape_to_symint(stride)
  314. with V.graph.sizevars.shape_env.suppress_guards():
  315. t = torch.empty_strided(
  316. size=size, stride=stride, dtype=dtype, device=device
  317. ).zero_()
  318. return t
  319. def may_convert_to_optional(
  320. value: Optional[Sequence[_T]],
  321. ) -> Optional[Sequence[Optional[_T]]]:
  322. if isinstance(value, list) and not value:
  323. # [None] makes sure the cpp wrapper codegen will generate something like
  324. # {std::nullopt} instead of {}
  325. return [None]
  326. return value
  327. def get_device_type(
  328. x: Union[IRNode, OutputSpec, torch.device, None, str],
  329. ) -> Optional[str]:
  330. if isinstance(x, str) or x is None:
  331. return x
  332. elif isinstance(x, torch.device):
  333. return x.type
  334. elif isinstance(x, (IRNode, OutputSpec)):
  335. return get_device_type(x.get_device())
  336. assert_never(f"get_device_type({x}: {type(x).__name__})")
  337. def is_triton(x: Union[IRNode, torch.device, None, str]) -> bool:
  338. device = get_device_type(x)
  339. # Special case cpu and cuda as using the method below
  340. # to determine if the scheduler is a triton scheduler subclass
  341. # requires instantiating a scheduler for them
  342. if device in ["cpu", "cuda"]:
  343. if getattr(config, f"{device}_backend") == "triton":
  344. return True
  345. return False
  346. if (
  347. device is None
  348. or (device_scheduling := get_scheduling_for_device(device)) is None
  349. ):
  350. return False
  351. from .codegen.triton import TritonScheduling
  352. assert isinstance(device_scheduling, type), type(device_scheduling)
  353. return issubclass(device_scheduling, TritonScheduling)
  354. def is_cpu(x: Union[IRNode, torch.device, None, str]) -> bool:
  355. return get_device_type(x) == "cpu"
  356. def is_aligned_realized_tensor_hint(
  357. x: Union[Buffer, TensorBox], alignment: int
  358. ) -> bool:
  359. # Use this as a hint. This won't guard since size_hint doesn't guard.
  360. if (
  361. not isinstance(x, IRNode)
  362. or x.maybe_get_stride() is None
  363. or free_unbacked_symbols(x.get_stride())
  364. or free_unbacked_symbols(x.get_size())
  365. ):
  366. return False
  367. aligned_strides = all(
  368. (V.graph.sizevars.size_hint_or_throw(x.get_stride()[i]) % alignment) == 0
  369. for i in range(len(x.get_stride()) - 1)
  370. )
  371. # if the last dim size is <= 1, stride doesn't matter
  372. aligned_last_dim = (
  373. V.graph.sizevars.size_hint_or_throw(x.get_stride()[-1]) == 1
  374. or V.graph.sizevars.size_hint_or_throw(x.get_size()[-1]) <= 1
  375. )
  376. return aligned_last_dim and aligned_strides
  377. def significant_strides_equal(
  378. strides1: Sequence[_IntLike],
  379. strides2: Sequence[_IntLike],
  380. shape: Sequence[_IntLike],
  381. ) -> bool:
  382. """
  383. Returns true if the strides are equal, ignoring dimensions of size 1 .
  384. """
  385. assert len(shape) == len(strides1) and len(strides1) == len(strides2)
  386. for dim, s1, s2 in zip(shape, strides1, strides2):
  387. if V.graph.sizevars.statically_known_leq(dim, 1):
  388. continue
  389. if not V.graph.sizevars.statically_known_equals(
  390. s1, s2
  391. ) and not V.graph.sizevars.symbolic_hint(s1) == V.graph.sizevars.symbolic_hint(
  392. s2
  393. ):
  394. return False
  395. return True
  396. def try_match_insignificant_strides(
  397. tensor: IRNode,
  398. strides: Sequence[Union[int, torch.SymInt]],
  399. ) -> IRNode:
  400. """
  401. Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
  402. dimensions - size 0 or 1 - will be updated.
  403. If there are real stride differences (NHWC vs NCHW), or the tensor is not realized, then the input will be returned
  404. """
  405. if not is_storage_and_layout(tensor):
  406. return tensor
  407. if all(
  408. V.graph.sizevars.statically_known_equals(s1, s2)
  409. for s1, s2 in zip(strides, tensor.get_stride())
  410. ):
  411. return tensor
  412. if not significant_strides_equal(strides, tensor.get_stride(), tensor.get_size()):
  413. return tensor
  414. storage, old_layout = as_storage_and_layout(tensor)
  415. new_stride = [*old_layout.stride]
  416. for i, s in enumerate(tensor.get_size()):
  417. if V.graph.sizevars.statically_known_leq(s, 1):
  418. new_stride[i] = strides[i]
  419. new_layout = FixedLayout(
  420. old_layout.device,
  421. old_layout.dtype,
  422. old_layout.size,
  423. new_stride,
  424. old_layout.offset,
  425. old_layout.is_pinned,
  426. )
  427. return TensorBox(ReinterpretView(data=storage, layout=new_layout))
  428. def gm_original_output_strides(gm: torch.fx.GraphModule) -> None:
  429. output_node = gm.graph.find_nodes(op="output")[0]
  430. output_node.meta["user_visible_output_idxs"] = [
  431. idx for idx, _ in enumerate(output_node.args)
  432. ]
  433. from torch._inductor.compile_fx import record_original_output_strides
  434. record_original_output_strides(gm)
  435. def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]:
  436. sym_vars: OrderedSet[Expr] = OrderedSet()
  437. for inp in inputs:
  438. sym_vars |= get_free_symbols(inp.get_size(), unbacked_only=False)
  439. sym_vars |= get_free_symbols(inp.get_stride(), unbacked_only=False)
  440. return list(sym_vars)
  441. class IRNode:
  442. """Base class for all intermediate representation (IR) nodes in TorchInductor.
  443. Note:
  444. This is an abstract base class. Most methods raise NotImplementedError
  445. and must be overridden by concrete subclasses.
  446. """
  447. _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet()
  448. # NB: These are kinda weird,
  449. origins: OrderedSet[Any] = dataclasses.field(init=False)
  450. # traces back to where the IRNode is created in Inductor
  451. traceback: Optional[list[str]] = dataclasses.field(init=False)
  452. origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)
  453. @staticmethod
  454. @contextlib.contextmanager
  455. def current_origins(origins: OrderedSet[Node]) -> Generator[None, None, None]:
  456. old = IRNode._current_origins
  457. IRNode._current_origins = old | origins
  458. try:
  459. yield
  460. finally:
  461. IRNode._current_origins = old
  462. @staticmethod
  463. def is_realized_node(node: IRNode) -> bool:
  464. return isinstance(
  465. node,
  466. (
  467. ComputedBuffer,
  468. InputsKernel,
  469. InputBuffer,
  470. ReinterpretView,
  471. TemplateBuffer,
  472. ),
  473. )
  474. def _post_init_setattr(self, attr: str, value: Any) -> None:
  475. # Intended for use in __post_init__ for enforcing an invariant on a dataclass
  476. # If you must, can also be used for setting provenance info
  477. # We would like to try and minimize these usages though
  478. object.__setattr__(self, attr, value)
  479. def __post_init__(self) -> None:
  480. origins = OrderedSet(self._current_origins)
  481. self._post_init_setattr("origins", origins)
  482. self._post_init_setattr(
  483. "traceback", traceback.format_stack() if config.debug_ir_traceback else None
  484. )
  485. self._post_init_setattr("origin_node", None)
  486. def get_read_names(self) -> OrderedSet[str]:
  487. return OrderedSet(dep.name for dep in self.get_reads())
  488. def get_traceback(self) -> Optional[list[str]]:
  489. return self.traceback
  490. def get_origin_node(self) -> Optional[torch.fx.Node]:
  491. return self.origin_node
  492. def get_defining_op(self) -> Optional[Operation]:
  493. return None
  494. def get_stack_traces(self) -> OrderedSet[str]:
  495. # Return stack traces to user model code
  496. # A single IRNode could correspond to multiple lines of code
  497. stack_traces: OrderedSet[str] = OrderedSet()
  498. origins = self.origins
  499. if isinstance(self, ExternKernel):
  500. origin_node = self.get_origin_node()
  501. if self.origin_node:
  502. origins = OrderedSet([origin_node])
  503. for node in origins:
  504. if hasattr(node, "stack_trace") and node.stack_trace:
  505. # nodes in the backward graph don't have mapping to pre_grad_graph
  506. stack_traces.add(node.stack_trace)
  507. else:
  508. pre_grad_nodes = (
  509. torch._inductor.debug._inductor_post_to_pre_grad_nodes.get(
  510. "postToPre", {}
  511. ).get(node.name, [])
  512. )
  513. if not isinstance(pre_grad_nodes, list):
  514. continue
  515. for node_name in pre_grad_nodes:
  516. stack_trace = (
  517. torch._inductor.debug._inductor_pre_grad_node_stack_trace.get(
  518. node_name, None
  519. )
  520. )
  521. if stack_trace:
  522. stack_traces.add(stack_trace)
  523. return stack_traces
  524. def common_repr(self, shorten: bool = True) -> Sequence[str]:
  525. origins = f"origins={getattr(self, 'origins', '')}"
  526. if shorten and len(origins) > 64:
  527. # this can get *very* long
  528. origins = f"{origins[:61]}..."
  529. if not self.get_stack_traces():
  530. return [origins]
  531. stack_trace_str = []
  532. for stack_trace in self.get_stack_traces():
  533. stack_trace_str.append("stack_traces = {")
  534. stack_trace_str += stack_trace.split("\n")
  535. stack_trace_str.append("}")
  536. return [origins] + stack_trace_str
  537. def str_helper(
  538. self, lines: Sequence[object], shorten: bool = True, multiline: bool = True
  539. ) -> str:
  540. lines = list(lines) + list(self.common_repr(shorten))
  541. lines = list(map(str, lines))
  542. if multiline:
  543. new_lines = indent(",\n".join(lines))
  544. return f"{type(self).__name__}(\n{new_lines}\n)"
  545. else:
  546. return f"{type(self).__name__}({lines})"
  547. def get_dtype(self) -> torch.dtype:
  548. return self.dtype
  549. def maybe_get_dtype(self) -> Optional[torch.dtype]:
  550. try:
  551. return self.get_dtype()
  552. except NotImplementedError:
  553. return None
  554. def get_layout(self) -> Layout:
  555. raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!")
  556. def maybe_get_layout(self) -> Optional[Layout]:
  557. try:
  558. return self.get_layout()
  559. except NotImplementedError:
  560. return None
  561. def get_output_spec(self) -> OutputSpec:
  562. return self.get_layout()
  563. def maybe_get_output_spec(self) -> Optional[OutputSpec]:
  564. try:
  565. return self.get_output_spec()
  566. except NotImplementedError:
  567. return None
  568. def has_tensor_output(self) -> bool:
  569. """True for single tensor output (excludes MultiOutput)"""
  570. return isinstance(self.maybe_get_output_spec(), Layout)
  571. def get_size(self) -> Sequence[Expr]:
  572. raise NotImplementedError(f"get_size() is not implemented by {type(self)}!")
  573. def maybe_get_size(self) -> Optional[Sequence[_IntLike]]:
  574. try:
  575. return self.get_size()
  576. except NotImplementedError:
  577. return None
  578. @property
  579. def shape(self) -> Union[_IntLike, sympy.Rel, Sequence[_IntLike]]:
  580. return self.get_size()
  581. def get_numel(self) -> Expr:
  582. return sympy_product(self.get_size())
  583. def is_zero_elements(self) -> bool:
  584. return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
  585. def realize(self) -> Optional[str]:
  586. """
  587. If the IRNode refers to data which has not been materialized (e.g.,
  588. it is a Pointwise/Reduction that could potentially have more
  589. compute fused into it), realize the IRNode into physical memory,
  590. ending the possibility of fusing into it, but allowing, e.g., multiple
  591. users to access the data without having to recompute.
  592. Check StorageBox.realize for a particularly notable implementation.
  593. TODO(ezyang): I think, in principle, every IRNode should have an
  594. implementation of this, and most of the time no-op is OK, but you
  595. really do have to audit each IRNode for this, so for now, raise
  596. an error if it's not implemented. Note that some code in graph.py
  597. will catch this thrown error and suppress it with a warning.
  598. """
  599. raise NotImplementedError(f"realize NYI on {type(self)}")
  600. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  601. raise NotImplementedError(f"codegen_reference NYI on {type(self)}")
  602. def get_device(self) -> Optional[torch.device]:
  603. return None
  604. def get_device_or_error(self) -> torch.device:
  605. device = self.get_device()
  606. assert device is not None
  607. return device
  608. def has_exceeded_max_reads(self) -> bool:
  609. return False
  610. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  611. raise NotImplementedError(type(self).__name__)
  612. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  613. raise NotImplementedError(type(self).__name__)
  614. def get_stride(self) -> Sequence[_IntLike]:
  615. raise NotImplementedError(type(self).__name__)
  616. def maybe_get_stride(self) -> Optional[Sequence[_IntLike]]:
  617. try:
  618. return self.get_stride()
  619. except NotImplementedError:
  620. return None
  621. def get_name(self) -> str:
  622. raise NotImplementedError(type(self).__name__)
  623. def maybe_get_name(self) -> Optional[str]:
  624. try:
  625. return self.get_name()
  626. except NotImplementedError:
  627. return None
  628. def is_input_buffer(self) -> bool:
  629. try:
  630. return self.get_name() in V.graph.graph_inputs
  631. except NotImplementedError:
  632. return False
  633. def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
  634. return False
  635. def mark_reuse(self, users: int) -> None:
  636. pass
  637. def realize_hint(self) -> None:
  638. pass
  639. def unwrap_view(self) -> IRNode:
  640. raise NotImplementedError(type(self).__name__)
  641. def freeze_layout(self) -> None:
  642. raise NotImplementedError(type(self).__name__)
  643. def freeze_layout_with_stride_order(
  644. self, order: Sequence[int], allow_padding: bool = False
  645. ) -> None:
  646. raise NotImplementedError(type(self).__name__)
  647. def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
  648. raise NotImplementedError(type(self).__name__)
  649. def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None:
  650. raise NotImplementedError(type(self).__name__)
  651. def freeze_layout_with_exact_strides(
  652. self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
  653. ) -> None:
  654. raise NotImplementedError(type(self).__name__)
  655. def get_read_writes(self) -> dependencies.ReadWrites:
  656. raise NotImplementedError(type(self).__name__)
  657. def get_reads(self) -> OrderedSet[Dep]:
  658. return self.get_read_writes().reads
  659. def num_reads(self) -> int:
  660. return len(self.get_reads())
  661. def get_storage_numel(self) -> _IntLike:
  662. raise NotImplementedError(type(self).__name__)
  663. def get_free_symbol_uses(
  664. self, unbacked_only: bool = False
  665. ) -> OrderedSet[sympy.Symbol]:
  666. raise NotImplementedError(type(self).__name__)
  667. def get_reduction_type(self) -> Optional[str]:
  668. raise NotImplementedError(type(self).__name__)
    def get_reduction_size(self) -> Sequence[Expr]:
        """Not implemented here: raises NotImplementedError naming the concrete class."""
        raise NotImplementedError(type(self).__name__)
    def is_extern(self) -> bool:
        """Base implementation: always returns False."""
        return False
    def is_no_op(self) -> bool:
        """Base implementation: always returns False."""
        return False
    def constant_to_device(self, device: torch.device) -> IRNode:
        """Not implemented here: raises NotImplementedError naming the concrete class."""
        raise NotImplementedError(type(self).__name__)
    def get_mutation_names(self) -> Sequence[str]:
        """Not implemented here: raises NotImplementedError naming the concrete class."""
        raise NotImplementedError(type(self).__name__)
    def get_operation_name(self) -> str:
        """Not implemented here: raises NotImplementedError naming the concrete class."""
        raise NotImplementedError(type(self).__name__)
    def get_inputs_that_alias_output(self) -> Sequence[str]:
        """Not implemented here: raises NotImplementedError naming the concrete class."""
        raise NotImplementedError(type(self).__name__)
    if TYPE_CHECKING:
        # Declaration guarded by TYPE_CHECKING: it tells static type checkers
        # that nodes expose a read-only ``dtype`` (presumably skipped at
        # runtime, assuming TYPE_CHECKING is typing.TYPE_CHECKING).

        @property
        def dtype(self) -> torch.dtype: ...
  686. @ir_dataclass(frozen=False)
  687. class Operation:
  688. def __post_init__(self) -> None:
  689. self.operation_name: Optional[str] = None
  690. def get_device(self) -> Optional[torch.device]:
  691. raise NotImplementedError
  692. def get_origin_node(self) -> Optional[torch.fx.Node]:
  693. assert hasattr(self, "origin_node")
  694. return self.origin_node
  695. def get_origins(self) -> OrderedSet[Any]:
  696. assert hasattr(self, "origins")
  697. return self.origins
  698. def get_operation_name(self) -> str:
  699. assert self.operation_name is not None
  700. return self.operation_name
  701. def is_extern(self) -> bool:
  702. return False
  703. def is_no_op(self) -> bool:
  704. return False
  705. def get_read_writes(self) -> dependencies.ReadWrites:
  706. raise NotImplementedError
  707. def is_user_of(self, name: str) -> bool:
  708. return name in self.get_read_names()
  709. def get_read_names(self) -> OrderedSet[str]:
  710. return OrderedSet(dep.name for dep in self.get_reads())
  711. def get_reads(self) -> OrderedSet[Dep]:
  712. return self.get_read_writes().reads
  713. def get_outputs(self) -> list[Buffer]:
  714. raise NotImplementedError
  715. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  716. return OrderedSet()
  717. def get_free_symbol_uses(
  718. self, unbacked_only: bool = False
  719. ) -> OrderedSet[sympy.Symbol]:
  720. """
  721. When unbacked_only=True:
  722. Returns the unbacked symbols which are required to be in scope in
  723. order to successfully perform codegen for this buffer. For example,
  724. a buffer that corresponds to an extern kernel call that takes i0 as
  725. an argument would return {i0} here. This is used to generate necessary
  726. dependencies that ensure we actually bind i0 in codegen before you
  727. try to use it.
  728. Note that this is NOT transitive; in particular, if this buffer takes
  729. in as input another buffer with dynamic shape (e.g., (i0,)), we will
  730. not report it here, because you will already have a dependency
  731. on that buffer, which will eventually have a dependency on i0 if
  732. necessary.
  733. When unbacked_only=False:
  734. Similar to `unbacked_only=True` but including all free symbols
  735. instead of only free unbacked symbols.
  736. """
  737. return OrderedSet()
  738. def get_workspace_size(self) -> int:
  739. """
  740. Gets extra global memory size needed by this buffer.
  741. Some algorithms (e.g. group gemm) may require extra global memory in the generated code.
  742. """
  743. return 0
  744. @ir_dataclass
  745. class Loops(IRNode):
  746. device: torch.device
  747. dtype: torch.dtype
  748. inner_fn: Callable[..., Any]
  749. ranges: Sequence[_IntLike]
  750. @cache_on_self_and_args("Loops")
  751. def get_free_symbol_uses(
  752. self, unbacked_only: bool = False
  753. ) -> OrderedSet[sympy.Symbol]:
  754. return OrderedSet().union(
  755. *(get_free_symbols(e, unbacked_only) for e in self.ranges),
  756. self.inner_fn_free_symbols(unbacked_only),
  757. )
  758. def _to_str(self, names: Sequence[str]) -> str:
  759. return self.str_helper(
  760. [
  761. f"'{self.device.type}'",
  762. str(self.dtype),
  763. self.inner_fn_str(),
  764. ]
  765. + [f"{name}={getattr(self, name)}" for name in names]
  766. + [f"origin_node={self.origin_node!r}"]
  767. )
  768. def __post_init__(self) -> None:
  769. super().__post_init__()
  770. def __str__(self) -> str:
  771. return self._to_str(("ranges",))
  772. __repr__ = __str__
  773. def get_device(self) -> Optional[torch.device]:
  774. return self.device
  775. def get_origin_node(self) -> Optional[torch.fx.Node]:
  776. return self.origin_node
  777. def get_size(self) -> Sequence[Expr]:
  778. return self.ranges
  779. def get_pointwise_size(self) -> Sequence[Expr]:
  780. return self.ranges
  781. @classmethod
  782. def create(
  783. cls, *args: Any, **kwargs: Any
  784. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  785. origin_node = kwargs.pop("origin_node", None)
  786. tb = kwargs.pop("traceback", None)
  787. r = cls(*args, **kwargs)
  788. # Need to explicitly set origin_node here to propagate it down.
  789. # todo(chilli): I think it would be better for IRNode to directly set
  790. # origin_node
  791. r._post_init_setattr("origin_node", origin_node)
  792. r._post_init_setattr("traceback", tb or r.traceback)
  793. return TensorBox.create(r)
  794. @staticmethod
  795. def _index(ranges: Sequence[_IntLike], prefix: SymT = SymT.INDEX) -> Sequence[Expr]:
  796. return [
  797. sympy.S.Zero if s == 1 else sympy_index_symbol_with_prefix(prefix, n)
  798. for n, s in enumerate(ranges)
  799. ]
  800. @cache_on_self
  801. def inner_fn_opcount(self) -> OpCountResult:
  802. opcounter = OpCounterCSE(V.MockHandler())
  803. with (
  804. V.set_ops_handler(opcounter),
  805. patch.object(FlexibleLayout, "allow_indexing", True),
  806. ):
  807. self.inner_fn(*self.inner_fn_args())
  808. return opcounter.getvalue()
  809. def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]:
  810. return (self._index(self.ranges),)
  811. @cache_on_self
  812. def inner_fn_str(self) -> str:
  813. return V.KernelFormatterHandler.ir_to_string(
  814. self.inner_fn, *self.inner_fn_args()
  815. )
  816. def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
  817. if threshold is None:
  818. threshold = 0
  819. threshold = max(threshold, config.realize_opcount_threshold)
  820. return self.inner_fn_opcount().num_ops > threshold
  821. def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  822. index = self._index(self.ranges)
  823. return extract_free_symbols(self.inner_fn, index, unbacked_only=unbacked_only)
  824. def get_reads(self) -> OrderedSet[Dep]:
  825. with patch.object(FlexibleLayout, "allow_indexing", True):
  826. if self.get_reduction_type():
  827. return extract_read_writes(
  828. self.make_loader(),
  829. self.get_size(),
  830. self.get_reduction_size(),
  831. ).reads
  832. else:
  833. return extract_read_writes(
  834. self.make_loader(),
  835. self.get_size(),
  836. ).reads
  837. def get_read_names(self) -> OrderedSet[str]:
  838. return OrderedSet(self.inner_fn_opcount().read_buffers)
  839. def num_reads(self) -> int:
  840. return len(self.inner_fn_opcount().read_buffers)
  841. def get_reduction_size(self) -> Sequence[Expr]:
  842. raise NotImplementedError(
  843. f"get_reduction_size() is not implemented by {type(self)}!"
  844. )
  845. def get_reduction_type(self) -> Optional[str]:
  846. raise NotImplementedError(
  847. f"get_reduction_type() is not implemented by {type(self)}!"
  848. )
  849. def constant_to_device(self, device: torch.device) -> IRNode:
  850. raise NotImplementedError(
  851. f"constant_to_device() is not implemented by {type(self)}!"
  852. )
  853. def nop_loader_fn(idx: Union[Expr, Sequence[Expr]], *, dtype: torch.dtype) -> OpsValue:
  854. if dtype.is_floating_point:
  855. return ops.constant(float("nan"), dtype)
  856. else:
  857. return ops.constant(0, dtype)
  858. @ir_dataclass
  859. class Pointwise(Loops):
  860. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  861. # Make zero-element loops into a no-op
  862. if self.is_zero_elements():
  863. return partial(nop_loader_fn, dtype=self.dtype)
  864. return self.inner_fn
  865. def get_reduction_size(self) -> Sequence[sympy.Expr]:
  866. return []
  867. def get_reduction_type(self) -> Optional[str]:
  868. return None
  869. def store_output(
  870. self,
  871. output_name: Optional[str],
  872. indexer: Callable[[Sequence[Expr]], Never],
  873. vars: Sequence[Expr],
  874. ) -> None:
  875. loader = self.make_loader()
  876. return ops.store(output_name or "unnamed", indexer(vars), loader(vars))
  877. def constant_to_device(self, device: torch.device) -> IRNode:
  878. """Move this to a given device. Requires that all reads are to constants."""
  879. loader = self.make_loader()
  880. loader = patch.object(ConstantBuffer, "override_device", device)(loader)
  881. return Pointwise(
  882. device=device,
  883. dtype=self.dtype,
  884. inner_fn=loader,
  885. ranges=self.ranges,
  886. )
  887. @ir_dataclass
  888. class Scatter(Pointwise):
  889. output_indexer: Callable[[Sequence[Expr]], Expr]
  890. scatter_mode: StoreMode = None
  891. def constant_to_device(self, device: torch.device) -> IRNode:
  892. """Move this to a given device. Requires that all reads are to constants."""
  893. loader = self.make_loader()
  894. loader = patch.object(ConstantBuffer, "override_device", device)(loader)
  895. return Scatter(
  896. device=device,
  897. dtype=self.dtype,
  898. inner_fn=loader,
  899. ranges=self.ranges,
  900. output_indexer=self.output_indexer,
  901. scatter_mode=self.scatter_mode,
  902. )
  903. def store_output(
  904. self,
  905. output_name: Optional[str],
  906. indexer: Callable[[Sequence[Expr]], Never],
  907. vars: Sequence[Expr],
  908. ) -> Any:
  909. loader = self.make_loader()
  910. if output_name is None:
  911. output_name = "unnamed"
  912. return ops.store(
  913. output_name,
  914. indexer(self.output_indexer(vars)),
  915. loader(vars),
  916. mode=self.scatter_mode,
  917. )
  918. REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = {
  919. "any": ops_wrapper("logical_or"),
  920. "max": ops_wrapper("maximum"),
  921. "min": ops_wrapper("minimum"),
  922. "prod": ops_wrapper("mul"),
  923. "sum": ops_wrapper("add"),
  924. "xor_sum": ops_wrapper("bitwise_xor"),
  925. }
  926. def get_reduction_combine_fn(
  927. reduction_type: str, dtype: torch.dtype, arg_break_ties_left: bool = True
  928. ) -> Callable[..., object]:
  929. if reduction_type in REDUCTION_COMBINE_FN:
  930. return REDUCTION_COMBINE_FN[reduction_type]
  931. elif reduction_type in ("argmax", "argmin"):
  932. def argmax_combine_fn(
  933. a: tuple[object, object], b: tuple[object, object]
  934. ) -> tuple[OpsValue, OpsValue]:
  935. a_value, a_index = a
  936. b_value, b_index = b
  937. if reduction_type == "argmin":
  938. mask = ops.lt(a_value, b_value)
  939. else:
  940. mask = ops.gt(a_value, b_value)
  941. equal = ops.eq(a_value, b_value)
  942. if is_float_dtype(dtype):
  943. a_isnan = ops.ne(a_value, a_value)
  944. b_isnan = ops.ne(b_value, b_value)
  945. mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan))
  946. equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan))
  947. tie = (
  948. ops.lt(a_index, b_index)
  949. if arg_break_ties_left
  950. else ops.gt(a_index, b_index)
  951. )
  952. mask = ops.logical_or(mask, ops.logical_and(equal, tie))
  953. return (
  954. ops.where(mask, a_value, b_value),
  955. ops.where(mask, a_index, b_index),
  956. )
  957. return argmax_combine_fn
  958. elif reduction_type == "welford_combine":
  959. def welford_combine_fn(
  960. a: tuple[OpsValue, OpsValue, OpsValue],
  961. b: tuple[OpsValue, OpsValue, OpsValue],
  962. ) -> tuple[OpsValue, OpsValue, OpsValue]:
  963. a_mean, a_m2, a_weight = a
  964. b_mean, b_m2, b_weight = b
  965. delta = b_mean - a_mean
  966. new_weight = a_weight + b_weight
  967. w2_over_w = b_weight / new_weight
  968. return (
  969. a_mean + delta * w2_over_w,
  970. a_m2 + b_m2 + delta * delta * a_weight * w2_over_w,
  971. new_weight,
  972. )
  973. return welford_combine_fn
  974. else:
  975. raise NotImplementedError(f"unknown reduction_type={reduction_type}")
@ir_dataclass
class Reduction(Loops):
    """
    Loops node that additionally carries the ranges being reduced over, the
    reduction type, the source dtype and a reduction hint.
    """

    reduction_ranges: Sequence[_IntLike]
    reduction_type: ReductionType
    # self.dtype represents the dst dtype
    src_dtype: torch.dtype
    reduction_hint: ReductionHint
    def __str__(self) -> str:
        """String form that includes ranges, reduction_ranges and reduction_type."""
        return self._to_str(("ranges", "reduction_ranges", "reduction_type"))

    # repr() uses the same text as str().
    __repr__ = __str__
  986. @cache_on_self_and_args("Reduction")
  987. def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  988. return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
  989. *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
  990. )
    def get_reduction_size(self) -> Sequence[Expr]:
        """Return ``reduction_ranges``."""
        return self.reduction_ranges
    def get_reduction_type(self) -> Optional[str]:
        """Return ``reduction_type``."""
        return self.reduction_type
  995. def store_reduction(
  996. self,
  997. output_name: Optional[str],
  998. indexer: Callable[[Sequence[Expr]], Never],
  999. vars: Sequence[Expr],
  1000. reduction_vars: Sequence[Symbol],
  1001. ) -> None:
  1002. value = ops.reduction(
  1003. self.dtype,
  1004. self.src_dtype,
  1005. self.reduction_type,
  1006. self.inner_fn(vars, reduction_vars),
  1007. )
  1008. ops.store_reduction(output_name or "unnamed", indexer(vars), value)
    def index_length(self) -> int:
        """Return the number of ranges plus the number of reduction ranges."""
        return len(self.ranges) + len(self.reduction_ranges)
  1011. def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
  1012. index = self._index(self.ranges)
  1013. rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
  1014. return (index, rindex)
  1015. def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  1016. index = self._index(self.ranges)
  1017. rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
  1018. return extract_free_symbols(
  1019. self.inner_fn, index, rindex, unbacked_only=unbacked_only
  1020. )
  1021. def constant_to_device(self, device: torch.device) -> IRNode:
  1022. """Move this to a given device. Requires that all reads are to constants."""
  1023. loader = self.make_loader()
  1024. loader = patch.object(ConstantBuffer, "override_device", device)(loader)
  1025. return Reduction(
  1026. device=device,
  1027. dtype=self.dtype,
  1028. inner_fn=loader,
  1029. ranges=self.ranges,
  1030. reduction_ranges=self.reduction_ranges,
  1031. reduction_type=self.reduction_type,
  1032. src_dtype=self.src_dtype,
  1033. reduction_hint=ReductionHint.DEFAULT,
  1034. )
    @staticmethod
    def num_splits(
        device: torch.device,
        dst_dtype: torch.dtype,
        src_dtype: torch.dtype,
        inner_fn: Callable[_P, OpsValue],
        ranges: Sequence[_IntLike],
        reduction_ranges: Sequence[_IntLike],
        reduction_type: Union[ReductionType, Literal["scan"]],
        reduction_numel: Expr,
        input_node: Optional[IRNode] = None,
    ) -> tuple[ReductionHint, _IntLike]:
        """
        Choose a reduction hint and a split factor for a reduction.

        Returns ``(hint, split)``. A split of 1 is returned on every path that
        decides not to split; -1 is returned when the ranges extracted from
        ``input_node`` should be used instead of a computed split.

        Note: this can call ``decide_layout()`` on buffers the reduction
        reads, so it is not free of side effects.
        """
        reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel)
        numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))

        # Splitting is always considered for "scan"; otherwise only when the
        # backend lacks REDUCE_TO_SINGLE_ELEMENT, the reduction is not
        # argmax/argmin, and config.split_reductions is enabled.
        should_split = reduction_type == "scan" or (
            not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
            and reduction_type
            not in (
                "argmax",
                "argmin",
            )
            and config.split_reductions
        )
        if not (_is_static(reduction_numel_hint) and _is_static(numel_hint)):
            # We don't support unbacked symints
            return ReductionHint.DEFAULT, 1

        props = DeviceProperties.create(device)
        num_sm = props.multi_processor_count
        min_elements_per_thread = 32

        if should_split:
            inner_reduction_splits: Callable[[int, int], int] = functools.partial(
                V.choices.reduction_split_factor, device, inner_reduction=True
            )
            outer_reduction_splits: Callable[[int, int], int] = functools.partial(
                V.choices.reduction_split_factor, device, inner_reduction=False
            )
        else:
            # Splitting disabled: both helpers always answer "one split".

            def inner_reduction_splits(
                reduction_numel_hint: int,
                numel_hint: int,
            ) -> int:
                return 1

            outer_reduction_splits = inner_reduction_splits

        # easy cases
        if numel_hint == 1:
            split = inner_reduction_splits(reduction_numel_hint, numel_hint)
            if split == 1:
                # No need to split.
                return ReductionHint.INNER, split
            if input_node is not None and isinstance(input_node, TensorBox):
                with patch.object(FlexibleLayout, "allow_indexing", True):
                    (
                        new_ranges,
                        new_reduction_ranges,
                    ) = extract_input_node_reduction_ranges(input_node)
                if new_ranges is not None and new_reduction_ranges is not None:
                    extracted_numel_hint = V.graph.sizevars.symbolic_hint(
                        sympy_product(new_ranges + new_reduction_ranges)
                    )
                    # Only reuse the extracted ranges when they cover exactly
                    # as many elements as this reduction.
                    if reduction_numel_hint == extracted_numel_hint:
                        log.debug(
                            "Use previous IRNode's range and reduction_ranges instead of split. "
                            "current ranges: %s, current reduction ranges: %s, current split: %d, "
                            "new ranges: %s, new reduction ranges: %s",
                            ranges,
                            reduction_ranges,
                            split,
                            new_ranges,
                            new_reduction_ranges,
                        )
                        # If the input_node or its dependent nodes are also Reduction nodes,
                        # use reduction_sizes of this node or its dependent nodes directly.
                        return ReductionHint.INNER, -1
            return ReductionHint.INNER, split
        if (
            reduction_numel_hint <= min_elements_per_thread
            or numel_hint >= num_sm * 2 * 32
        ):
            return ReductionHint.DEFAULT, 1

        # Temporary Reduction node used only to analyze its reads; "scan" is
        # analyzed as if it were "sum".
        r = Reduction(
            device=device,
            dtype=dst_dtype,
            inner_fn=inner_fn,
            ranges=ranges,
            reduction_ranges=reduction_ranges,
            reduction_type=reduction_type if reduction_type != "scan" else "sum",
            src_dtype=src_dtype,
            reduction_hint=ReductionHint.DEFAULT,
        )

        def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]:
            # Returns the index expressions of the reads that mention every
            # range variable, plus a flag telling whether decide_layout()
            # changed the stride of any buffer that was read.
            device = r.get_device()
            assert device is not None
            cb = ComputedBuffer(
                name=None,
                layout=FlexibleLayout(
                    device=device,
                    dtype=r.get_dtype(),
                    size=r.get_size(),
                ),
                data=r,
            )
            read_writes = cb.get_read_writes()
            # try finding the full size producer
            # TODO this will fail for something like ((1, N) * (N, 1)).sum()
            # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
            assert read_writes.range_vars is not None
            range_vars = [
                r
                for r in read_writes.range_vars
                if isinstance(r, Expr) and not isinstance(r, sympy.Number)
            ]
            indices = []
            changed = False
            # Sorted by name so the iteration order is deterministic.
            for md in sorted(read_writes.reads, key=lambda x: x.name):
                if all(r in md.index.free_symbols for r in range_vars):
                    indices.append(md.index)
                    if md.name in V.graph.name_to_buffer:
                        buf = V.graph.name_to_buffer[md.name]
                        original_stride = getattr(buf.layout, "stride", None)
                        buf.decide_layout()
                        if getattr(buf.layout, "stride", None) != original_stride:
                            changed = True
            return indices, changed

        indices, changed = get_read_indices(r)
        if changed:
            # A layout changed during the first pass, so recompute the indices.
            indices, _ = get_read_indices(r)

        if len(indices) == 0:
            # TODO determine splits when all inputs are broadcast
            return ReductionHint.DEFAULT, 1

        (_, reduction_vars), ranges1 = dependencies.index_vars_squeeze(
            r.get_size(), r.get_reduction_size()
        )
        num_outer = 0
        num_inner = 0
        for i in indices:
            j = V.graph.sizevars.simplify_with_ranges(i, ranges1)
            strides = V.graph.sizevars.stride_hints(
                j, reduction_vars, list(ranges1.keys())
            )
            # A read counts as "outer" when every stride hint w.r.t. the
            # reduction variables is greater than 1.
            outer = all(s > 1 for s in strides)
            if outer:
                num_outer += 1
            else:
                num_inner += 1
        # Majority vote between inner and outer reads; ties go to OUTER.
        if num_inner > num_outer:
            return ReductionHint.INNER, inner_reduction_splits(
                reduction_numel_hint, numel_hint
            )
        else:
            return ReductionHint.OUTER, outer_reduction_splits(
                reduction_numel_hint, numel_hint
            )
  1187. @staticmethod
  1188. def _unroll_reduction_fn(
  1189. inner_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], OpsValue],
  1190. reduction_ranges: Sequence[_IntLike],
  1191. reduction_type: str,
  1192. src_dtype: torch.dtype,
  1193. ) -> Callable[[Sequence[_IntLike]], OpsValue]:
  1194. """Convert inner_fn from a reduction to an pointwise"""
  1195. reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges)
  1196. combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)
  1197. def fn(index: Sequence[_IntLike]) -> Any:
  1198. return functools.reduce(
  1199. combine_fn,
  1200. (
  1201. value_fn(index, rindex)
  1202. for rindex in itertools.product(
  1203. *[range(x) for x in reduction_ranges]
  1204. )
  1205. ),
  1206. )
  1207. value_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Any]
  1208. if reduction_type in ("argmin", "argmax"):
  1209. flatten_index = _fixed_indexer(
  1210. reduction_ranges,
  1211. FlexibleLayout.contiguous_strides(reduction_ranges),
  1212. )
  1213. def value_fn(
  1214. index: Sequence[_IntLike], rindex: Sequence[_IntLike]
  1215. ) -> tuple[OpsValue, OpsValue]:
  1216. rindex = [sympy.expand(i) for i in rindex]
  1217. return (
  1218. inner_fn(index, rindex),
  1219. ops.index_expr(flatten_index(rindex), torch.int64),
  1220. )
  1221. return lambda index: fn(index)[1]
  1222. else:
  1223. value_fn = inner_fn
  1224. return fn
  1225. @classmethod
  1226. def create(
  1227. cls,
  1228. device: torch.device,
  1229. dst_dtype: torch.dtype,
  1230. src_dtype: torch.dtype,
  1231. inner_fn: Callable[..., Any],
  1232. ranges: Sequence[Expr],
  1233. reduction_ranges: Sequence[Expr],
  1234. reduction_type: ReductionType,
  1235. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  1236. input_node: Optional[IRNode] = None,
  1237. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1238. reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
  1239. if reduction_numel == 0:
  1240. # N.B. This is a hack to generate the literal of the given type
  1241. # Ideally, we should be fixing `def constant` in triton.py
  1242. # but it breaks due to hardcoded dtypes in other places
  1243. def py_cnst(val: object) -> Union[bool, float, int]:
  1244. if dst_dtype == torch.bool:
  1245. return bool(val)
  1246. elif dst_dtype.is_floating_point:
  1247. assert isinstance(val, SupportsFloat), type(val)
  1248. return float(val)
  1249. else:
  1250. assert isinstance(val, SupportsInt), type(val)
  1251. return int(val)
  1252. rtypes_to_inits = {
  1253. "sum": py_cnst(0),
  1254. "xor_sum": py_cnst(0),
  1255. "prod": py_cnst(1),
  1256. "any": py_cnst(0),
  1257. # "all" is desugared to `!any(!val)`
  1258. }
  1259. assert reduction_type in rtypes_to_inits.keys(), (
  1260. f"{reduction_type} not supported for zero-dimension tensors!"
  1261. )
  1262. def const_fn(index: int) -> OpsValue:
  1263. return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
  1264. return Pointwise.create(
  1265. device=device,
  1266. dtype=src_dtype,
  1267. inner_fn=const_fn,
  1268. ranges=list(ranges),
  1269. )
  1270. if reduction_numel == 1:
  1271. # this reduction is actually a pointwise op
  1272. if reduction_type in ("argmin", "argmax"):
  1273. def fn(index: int) -> OpsValue:
  1274. return ops.constant(0, dst_dtype)
  1275. else:
  1276. def fn(index: int) -> OpsValue:
  1277. reduction_index = [sympy.S.Zero for _ in reduction_ranges]
  1278. return inner_fn(index, reduction_index)
  1279. return Pointwise.create(
  1280. device=device, dtype=dst_dtype, inner_fn=fn, ranges=ranges
  1281. )
  1282. if (
  1283. isinstance(reduction_numel, Integer)
  1284. and V.graph.sizevars.size_hint_or_throw(reduction_numel)
  1285. < config.unroll_reductions_threshold
  1286. and (sympy_product(ranges) != 1 or is_gpu(device.type))
  1287. ):
  1288. # NB: This works around https://github.com/pytorch/pytorch/issues/140457
  1289. # since turning reductions into pointwise ops can exacerbate this problem
  1290. return Pointwise.create(
  1291. device=device,
  1292. dtype=dst_dtype,
  1293. inner_fn=cls._unroll_reduction_fn(
  1294. inner_fn, reduction_ranges, reduction_type, src_dtype
  1295. ),
  1296. ranges=ranges,
  1297. )
  1298. # triton doesn't support reduce to single element well, so break it up
  1299. hint, split = cls.num_splits(
  1300. device,
  1301. dst_dtype,
  1302. src_dtype,
  1303. inner_fn,
  1304. ranges,
  1305. reduction_ranges,
  1306. reduction_type,
  1307. reduction_numel,
  1308. input_node,
  1309. )
  1310. def _maybe_increase_split(split: int) -> int:
  1311. # don't apply min_num_split constraint for static shape case.
  1312. if _is_static(reduction_numel):
  1313. return split
  1314. if split > 1:
  1315. return max(split, config.min_num_split)
  1316. else:
  1317. return split
  1318. split = _maybe_increase_split(split)
  1319. # intermediate reduction in split can contain complex indexing,
  1320. # and num_splits will fail to correctly set the hint
  1321. # reuse the passed hint if available
  1322. if reduction_hint == ReductionHint.DEFAULT:
  1323. reduction_hint = hint
  1324. if split == -1:
  1325. assert input_node is not None
  1326. with patch.object(FlexibleLayout, "allow_indexing", True):
  1327. new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
  1328. input_node
  1329. )
  1330. assert new_ranges is not None
  1331. assert new_reduction_ranges is not None
  1332. return cls.create_multilayer_existing_ranges(
  1333. device,
  1334. dst_dtype,
  1335. src_dtype,
  1336. inner_fn,
  1337. ranges,
  1338. reduction_ranges,
  1339. new_ranges,
  1340. new_reduction_ranges,
  1341. reduction_type,
  1342. reduction_hint,
  1343. )
  1344. elif split > 1:
  1345. # triton doesn't support reduce to single element well, so break it up
  1346. return cls.create_multilayer(
  1347. device,
  1348. dst_dtype,
  1349. src_dtype,
  1350. inner_fn,
  1351. ranges,
  1352. reduction_ranges,
  1353. reduction_type,
  1354. split,
  1355. reduction_hint,
  1356. input_node,
  1357. )
  1358. return TensorBox.create(
  1359. Reduction(
  1360. device=device,
  1361. dtype=dst_dtype,
  1362. inner_fn=inner_fn,
  1363. ranges=ranges,
  1364. reduction_ranges=reduction_ranges,
  1365. reduction_type=reduction_type,
  1366. src_dtype=src_dtype,
  1367. reduction_hint=reduction_hint,
  1368. )
  1369. )
  1370. @staticmethod
  1371. def default_accumulator(
  1372. reduction_type: str, dtype: torch.dtype
  1373. ) -> Union[_NumLike, Sequence[_NumLike]]:
  1374. if reduction_type in ("max", "argmax"):
  1375. if is_float_dtype(dtype):
  1376. return float("-inf")
  1377. elif is_boolean_dtype(dtype):
  1378. return False
  1379. else:
  1380. return torch.iinfo(dtype).min
  1381. if reduction_type in ("min", "argmin"):
  1382. if is_float_dtype(dtype):
  1383. return float("inf")
  1384. elif is_boolean_dtype(dtype):
  1385. return True
  1386. else:
  1387. return torch.iinfo(dtype).max
  1388. zero = False if is_boolean_dtype(dtype) else 0
  1389. one = True if is_boolean_dtype(dtype) else 1
  1390. return {
  1391. "sum": zero,
  1392. "prod": one,
  1393. "xor_sum": zero,
  1394. "any": zero,
  1395. "welford_reduce": (zero, zero, zero),
  1396. "welford_combine": (zero, zero, zero),
  1397. "online_softmax_reduce": (float("-inf"), zero),
  1398. }[reduction_type]
  1399. @staticmethod
  1400. def default_value(
  1401. reduction_type: str, dtype: torch.dtype
  1402. ) -> Union[_NumLike, Sequence[_NumLike]]:
  1403. if reduction_type == "welford_reduce":
  1404. return 0
  1405. return Reduction.default_accumulator(reduction_type, dtype)
  1406. @staticmethod
  1407. def _multilayer_second_step_hint(
  1408. split: _IntLike, numel_hint: int, reduction_hint: ReductionHint
  1409. ) -> ReductionHint:
  1410. if split == -1:
  1411. return reduction_hint
  1412. if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
  1413. return ReductionHint.OUTER_TINY
  1414. if (
  1415. split <= 1024
  1416. and numel_hint <= 256
  1417. and reduction_hint == ReductionHint.OUTER
  1418. ):
  1419. return ReductionHint.OUTER_TINY
  1420. return reduction_hint
  1421. @classmethod
  1422. def check_for_split_dense_dim_reindexing(
  1423. cls, reduction_numel: _IntLike, input_node: Optional[IRNode]
  1424. ) -> Optional[int]:
  1425. """
  1426. If we are reducing over the full tensor, and it is non-dense in the last dimension,
  1427. reindex so we reduce over the dense dimension. initially just handle complete
  1428. reduction case
  1429. """
  1430. if input_node is None:
  1431. return None
  1432. if not V.graph.sizevars.statically_known_equals(
  1433. input_node.get_numel(), reduction_numel
  1434. ):
  1435. return None
  1436. input_node.realize()
  1437. try:
  1438. # finalize layout
  1439. as_storage_and_layout(input_node)
  1440. except NotImplementedError:
  1441. return None
  1442. strides = input_node.get_stride()
  1443. for i, s in enumerate(strides[:-1]):
  1444. if V.graph.sizevars.statically_known_equals(s, 1):
  1445. return i
  1446. return None
  1447. @classmethod
  1448. def _multilayer_wrap_loader(
  1449. cls,
  1450. loader: Callable[..., OpsValue],
  1451. reduction_ranges: Sequence[_IntLike],
  1452. reduction_numel: _IntLike,
  1453. split: _IntLike,
  1454. block_size: _IntLike,
  1455. default: Union[_NumLike, Sequence[_NumLike]],
  1456. input_node: Optional[IRNode] = None,
  1457. ) -> Callable[..., object]:
  1458. dense_index = cls.check_for_split_dense_dim_reindexing(
  1459. reduction_numel, input_node
  1460. )
  1461. reindex = View.dynamic_reshape_indexer(
  1462. reduction_ranges, [reduction_numel], dense_index
  1463. )
  1464. need_mask = not V.graph.sizevars.statically_known_true(
  1465. sympy.Eq(reduction_numel % split, 0)
  1466. )
  1467. def wrapper_fn(
  1468. index: Sequence[Symbol], reduction_index: Sequence[Symbol]
  1469. ) -> OpsValue:
  1470. (reduction_index,) = reduction_index
  1471. *new_index, reduction_block = index
  1472. indices = block_size * reduction_block + reduction_index
  1473. def body() -> OpsValue:
  1474. return loader(new_index, reindex([indices]))
  1475. if need_mask:
  1476. index_dtype = dtype_from_size(reduction_numel)
  1477. mask = ops.lt(
  1478. ops.index_expr(indices, index_dtype),
  1479. ops.index_expr(reduction_numel, index_dtype),
  1480. )
  1481. return ops.masked(mask, body, default)
  1482. else:
  1483. return body()
  1484. return wrapper_fn
  1485. @classmethod
  1486. def _multilayer_wrap_loader_existing_ranges(
  1487. cls,
  1488. loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
  1489. original_ranges: Sequence[Expr],
  1490. original_reduction_ranges: Sequence[Expr],
  1491. new_ranges: Sequence[Integer],
  1492. new_reduction_ranges: Sequence[Integer],
  1493. ) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]:
  1494. assert all(r == 1 for r in original_ranges), (
  1495. f"Only enabled for numel_hint == 1, found {original_ranges=}"
  1496. )
  1497. reindex = View.dynamic_reshape_indexer(
  1498. original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
  1499. )
  1500. def wrapper_fn(
  1501. merged_index: Sequence[Expr],
  1502. new_reduction_index: Sequence[Expr],
  1503. ) -> OpsValue:
  1504. original_idx = merged_index[: len(original_ranges)]
  1505. new_index = merged_index[len(original_ranges) :]
  1506. return loader(
  1507. original_idx,
  1508. reindex(tuple(new_index) + tuple(new_reduction_index)),
  1509. )
  1510. return wrapper_fn
  1511. @classmethod
  1512. def create_multilayer_helper(
  1513. cls,
  1514. device: torch.device,
  1515. dst_dtype: torch.dtype,
  1516. src_dtype: torch.dtype,
  1517. wrapper_fn: Callable[..., Any],
  1518. original_ranges: Sequence[Expr],
  1519. original_reduction_ranges: Sequence[Expr],
  1520. new_ranges: list[Expr],
  1521. new_reduction_ranges: list[Integer],
  1522. reduction_type: ReductionType,
  1523. split: _IntLike,
  1524. reduction_hint: ReductionHint,
  1525. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1526. """
  1527. Break a large reduction up into multiple smaller reductions
  1528. recursively
  1529. """
  1530. # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
  1531. # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
  1532. # in fp32 and not reduce precision by breaking up the kernel into multiple layers
  1533. intermediate_dtype = (
  1534. dst_dtype
  1535. if dst_dtype not in (torch.float16, torch.bfloat16)
  1536. else torch.float
  1537. )
  1538. intermediate = Reduction.create(
  1539. device,
  1540. intermediate_dtype,
  1541. src_dtype,
  1542. wrapper_fn,
  1543. new_ranges,
  1544. new_reduction_ranges,
  1545. reduction_type,
  1546. reduction_hint,
  1547. )
  1548. intermediate.realize()
  1549. intermediate_loader = intermediate.make_loader()
  1550. def intermediate_fn(
  1551. index: Sequence[_IntLike], reduction_index: Sequence[_IntLike]
  1552. ) -> OpsValue:
  1553. return intermediate_loader([*index, *reduction_index])
  1554. numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges))
  1555. reduction_hint = cls._multilayer_second_step_hint(
  1556. split, numel_hint, reduction_hint
  1557. )
  1558. assert original_ranges == new_ranges[: len(original_ranges)]
  1559. return TensorBox.create(
  1560. Reduction(
  1561. device=device,
  1562. dtype=dst_dtype,
  1563. inner_fn=intermediate_fn,
  1564. ranges=original_ranges,
  1565. reduction_ranges=new_ranges[len(original_ranges) :],
  1566. reduction_type=reduction_type,
  1567. src_dtype=src_dtype,
  1568. reduction_hint=reduction_hint,
  1569. )
  1570. )
  1571. @classmethod
  1572. def create_multilayer(
  1573. cls,
  1574. device: torch.device,
  1575. dst_dtype: torch.dtype,
  1576. src_dtype: torch.dtype,
  1577. inner_fn: Callable[..., Any],
  1578. ranges: Sequence[Expr],
  1579. reduction_ranges: Sequence[Expr],
  1580. reduction_type: ReductionType,
  1581. split: _IntLike,
  1582. reduction_hint: ReductionHint,
  1583. input_node: Optional[IRNode] = None,
  1584. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1585. """
  1586. Break a large reduction up into multiple smaller reductions
  1587. recursively
  1588. """
  1589. # TODO(jansel): realize the reduction so we can do dynamic indexing
  1590. reduction_numel = sympy_product(reduction_ranges)
  1591. block_size = FloorDiv(reduction_numel + (split - 1), split)
  1592. default = cls.default_value(reduction_type, dst_dtype)
  1593. wrapper_fn = cls._multilayer_wrap_loader(
  1594. inner_fn,
  1595. reduction_ranges,
  1596. reduction_numel,
  1597. split,
  1598. block_size,
  1599. default,
  1600. input_node,
  1601. )
  1602. return cls.create_multilayer_helper(
  1603. device,
  1604. dst_dtype,
  1605. src_dtype,
  1606. wrapper_fn,
  1607. ranges,
  1608. reduction_ranges,
  1609. [*ranges, split],
  1610. [block_size],
  1611. reduction_type,
  1612. split,
  1613. reduction_hint,
  1614. )
  1615. @classmethod
  1616. def create_multilayer_existing_ranges(
  1617. cls,
  1618. device: torch.device,
  1619. dst_dtype: torch.dtype,
  1620. src_dtype: torch.dtype,
  1621. inner_fn: Callable[..., Any],
  1622. original_ranges: Sequence[Expr],
  1623. original_reduction_ranges: Sequence[Expr],
  1624. new_ranges: list[Integer],
  1625. new_reduction_ranges: list[Integer],
  1626. reduction_type: ReductionType,
  1627. reduction_hint: ReductionHint,
  1628. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1629. """
  1630. Break a large reduction up into multiple smaller reductions
  1631. recursively
  1632. """
  1633. wrapper_fn = cls._multilayer_wrap_loader_existing_ranges(
  1634. inner_fn,
  1635. original_ranges,
  1636. original_reduction_ranges,
  1637. new_ranges,
  1638. new_reduction_ranges,
  1639. )
  1640. return cls.create_multilayer_helper(
  1641. device,
  1642. dst_dtype,
  1643. src_dtype,
  1644. wrapper_fn,
  1645. original_ranges,
  1646. original_reduction_ranges,
  1647. [*original_ranges, *new_ranges],
  1648. new_reduction_ranges,
  1649. reduction_type,
  1650. -1,
  1651. reduction_hint,
  1652. )
  1653. def _fixed_indexer(
  1654. size: Sequence[int],
  1655. stride: Optional[Sequence[int]] = None,
  1656. offset: Expr = Integer(0),
  1657. ) -> Callable[[Sequence[Expr]], Expr]:
  1658. """A closure containing math to read a given element"""
  1659. def indexer(index: Sequence[int]) -> int:
  1660. assert stride is not None and len(index) == len(stride)
  1661. assert len(index) == len(size)
  1662. result = offset
  1663. for idx, st, sz in zip(index, stride, size):
  1664. if sz != 1:
  1665. result = result + idx * st
  1666. return result
  1667. return indexer
# Signature of a reduction loader: called as
# (pointwise index, reduction index) and returns the loaded value.
INNER_FN_TY: TypeAlias = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]
  1669. class MultiOutputReduction(Reduction):
  1670. output_index: int
  1671. def __init__(
  1672. self,
  1673. device: torch.device,
  1674. dst_dtype: torch.dtype,
  1675. inner_fns: Union[INNER_FN_TY, Sequence[INNER_FN_TY]],
  1676. ranges: Sequence[Integer],
  1677. reduction_ranges: Sequence[Integer],
  1678. reduction_type: ReductionType,
  1679. src_dtype: torch.dtype,
  1680. reduction_hint: ReductionHint,
  1681. output_index: int,
  1682. ):
  1683. if callable(inner_fns):
  1684. inner_fns = (inner_fns,)
  1685. loader: Callable[[Sequence[Expr], Sequence[Expr]], Any]
  1686. if len(inner_fns) == 1:
  1687. loader = inner_fns[0]
  1688. else:
  1689. def loader(
  1690. idx: Sequence[Expr], reduction_idx: Sequence[Expr]
  1691. ) -> tuple[OpsValue, ...]:
  1692. return tuple(fn(idx, reduction_idx) for fn in inner_fns)
  1693. super().__init__(
  1694. device=device,
  1695. dtype=dst_dtype,
  1696. inner_fn=loader,
  1697. ranges=ranges,
  1698. reduction_ranges=reduction_ranges,
  1699. reduction_type=reduction_type,
  1700. src_dtype=src_dtype,
  1701. reduction_hint=reduction_hint,
  1702. )
  1703. self.output_index = output_index
  1704. def store_reduction(
  1705. self,
  1706. output_name: Optional[str],
  1707. indexer: Callable[[Sequence[Expr]], Never],
  1708. vars: Sequence[Expr],
  1709. reduction_vars: Sequence[Symbol],
  1710. ) -> Any:
  1711. values = ops.reduction(
  1712. self.dtype,
  1713. self.src_dtype,
  1714. self.reduction_type,
  1715. self.inner_fn(vars, reduction_vars),
  1716. )
  1717. assert isinstance(values, (tuple, list)), type(values)
  1718. value = values[self.output_index]
  1719. return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
  1720. class OnlineSoftmaxReduction(MultiOutputReduction):
  1721. @classmethod
  1722. def create( # type: ignore[override]
  1723. cls,
  1724. device: torch.device,
  1725. dst_dtype: torch.dtype,
  1726. src_dtype: torch.dtype,
  1727. inner_fn: Callable[..., Any],
  1728. ranges: Sequence[Expr],
  1729. reduction_ranges: Sequence[Expr],
  1730. num_output: int,
  1731. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  1732. input_node: Optional[IRNode] = None,
  1733. ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]:
  1734. """
  1735. Create the reduction disregarding splitting.
  1736. """
  1737. results = tuple(
  1738. TensorBox.create(
  1739. MultiOutputReduction(
  1740. device,
  1741. dst_dtype,
  1742. inner_fn,
  1743. ranges,
  1744. reduction_ranges,
  1745. "online_softmax_reduce",
  1746. src_dtype,
  1747. reduction_hint,
  1748. output_idx,
  1749. )
  1750. )
  1751. for output_idx in range(num_output)
  1752. )
  1753. for t in results:
  1754. t.realize()
  1755. return results
  1756. class WelfordReduction(MultiOutputReduction):
  1757. @classmethod
  1758. def create( # type: ignore[override]
  1759. cls,
  1760. device: torch.device,
  1761. dtype: torch.dtype,
  1762. inner_fns: Sequence[Callable[..., Any]],
  1763. ranges: list[Integer],
  1764. reduction_ranges: list[Integer],
  1765. reduction_type: ReductionType,
  1766. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  1767. ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]:
  1768. assert reduction_type in ("welford_reduce", "welford_combine")
  1769. reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
  1770. def const(val: int) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1771. def inner_fn(idx: Sequence[Expr]) -> OpsValue:
  1772. return ops.constant(
  1773. val,
  1774. dtype,
  1775. )
  1776. return Pointwise.create(
  1777. device=device,
  1778. dtype=dtype,
  1779. inner_fn=inner_fn,
  1780. ranges=list(ranges),
  1781. )
  1782. if reduction_numel == 0:
  1783. mean = const(0)
  1784. m2 = const(0)
  1785. weight = const(0)
  1786. return mean, m2, weight
  1787. if reduction_numel == 1:
  1788. def copy(
  1789. loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
  1790. ) -> Union[TensorBox, ShapeAsConstantBuffer]:
  1791. def inner_fn(idx: Sequence[Expr]) -> OpsValue:
  1792. reduction_index = [sympy.S.Zero for _ in reduction_ranges]
  1793. return loader(idx, reduction_index)
  1794. return Pointwise.create(
  1795. device=device,
  1796. dtype=dtype,
  1797. inner_fn=inner_fn,
  1798. ranges=list(ranges),
  1799. )
  1800. if reduction_type == "welford_reduce":
  1801. return copy(inner_fns[0]), const(0), const(1)
  1802. else:
  1803. return tuple(copy(fn) for fn in inner_fns)
  1804. # TODO: Unrolled reduction
  1805. # if (
  1806. # isinstance(reduction_numel, Integer)
  1807. # and V.graph.sizevars.size_hint(reduction_numel)
  1808. # < config.unroll_reductions_threshold
  1809. # and sympy_product(ranges) != 1
  1810. # ):
  1811. # return Pointwise.create(
  1812. # device,
  1813. # dst_dtype,
  1814. # cls._unroll_reduction_fn(
  1815. # inner_fn, reduction_ranges, reduction_type, src_dtype,
  1816. # ),
  1817. # ranges,
  1818. # )
  1819. # triton doesn't support reduce to single element well, so break it up
  1820. hint, split = Reduction.num_splits(
  1821. device,
  1822. dtype,
  1823. dtype,
  1824. inner_fns[0],
  1825. ranges,
  1826. reduction_ranges,
  1827. reduction_type=reduction_type,
  1828. reduction_numel=reduction_numel,
  1829. )
  1830. # intermediate reduction in split can contain complex indexing,
  1831. # and num_splits will fail to correctly set the hint
  1832. # reuse the passed hint if available
  1833. if reduction_hint == ReductionHint.DEFAULT:
  1834. reduction_hint = hint
  1835. if split > 1:
  1836. # triton doesn't support reduce to single element well, so break it up
  1837. return cls.create_multilayer(
  1838. device,
  1839. dtype,
  1840. inner_fns,
  1841. ranges,
  1842. reduction_ranges,
  1843. reduction_type,
  1844. split,
  1845. reduction_hint,
  1846. )
  1847. results = [
  1848. TensorBox.create(
  1849. WelfordReduction(
  1850. device,
  1851. dtype,
  1852. inner_fns,
  1853. ranges,
  1854. reduction_ranges,
  1855. reduction_type,
  1856. dtype,
  1857. reduction_hint,
  1858. output_idx,
  1859. )
  1860. )
  1861. for output_idx in range(3)
  1862. ]
  1863. for t in results:
  1864. t.realize()
  1865. return results
  1866. @staticmethod
  1867. def default_value(
  1868. reduction_type: str, dtype: torch.dtype
  1869. ) -> Union[_NumLike, Sequence[_NumLike]]:
  1870. return (0, 0, 0)
  1871. @classmethod
  1872. def create_multilayer( # type: ignore[override]
  1873. cls,
  1874. device: torch.device,
  1875. dtype: torch.dtype,
  1876. inner_fns: Sequence[Callable[..., Any]],
  1877. ranges: list[Integer],
  1878. reduction_ranges: list[Integer],
  1879. reduction_type: ReductionType,
  1880. split: _IntLike,
  1881. reduction_hint: ReductionHint,
  1882. ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]:
  1883. """
  1884. Break a large reduction up into multiple smaller reductions
  1885. recursively
  1886. """
  1887. reduction_numel = sympy_product(reduction_ranges)
  1888. need_mask = not V.graph.sizevars.statically_known_true(
  1889. sympy.Eq(reduction_numel % split, 0)
  1890. )
  1891. if need_mask and reduction_type != "welford_combine":
  1892. # If we need mask, then "welford_reduce" doesn't work because
  1893. # masked inputs shouldn't count towards the welford weight
  1894. def constant(
  1895. idx: Sequence[Expr], reduction_idx: Sequence[Expr], value: int
  1896. ) -> OpsValue:
  1897. return ops.constant(value, dtype)
  1898. return cls.create_multilayer(
  1899. device=device,
  1900. dtype=dtype,
  1901. inner_fns=(
  1902. inner_fns[0],
  1903. partial(constant, value=0),
  1904. partial(constant, value=1),
  1905. ),
  1906. ranges=ranges,
  1907. reduction_ranges=reduction_ranges,
  1908. reduction_type="welford_combine",
  1909. split=split,
  1910. reduction_hint=reduction_hint,
  1911. )
  1912. block_size = FloorDiv(reduction_numel + (split - 1), split)
  1913. intermediates = WelfordReduction.create(
  1914. device,
  1915. dtype,
  1916. tuple(
  1917. cls._multilayer_wrap_loader(
  1918. loader,
  1919. reduction_ranges,
  1920. reduction_numel,
  1921. split,
  1922. block_size,
  1923. default=0,
  1924. )
  1925. for loader in inner_fns
  1926. ),
  1927. [*ranges, split],
  1928. [block_size],
  1929. reduction_type,
  1930. reduction_hint,
  1931. )
  1932. for i in intermediates:
  1933. i.realize()
  1934. def intermediate_loader_fn(
  1935. index: Sequence[Expr],
  1936. reduction_index: Sequence[Expr],
  1937. loader: Callable[[Sequence[Expr]], OpsValue],
  1938. ) -> OpsValue:
  1939. return loader([*index, *reduction_index])
  1940. numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
  1941. reduction_hint = cls._multilayer_second_step_hint(
  1942. split, numel_hint, reduction_hint
  1943. )
  1944. return WelfordReduction.create(
  1945. device,
  1946. dtype,
  1947. tuple(
  1948. partial(intermediate_loader_fn, loader=i.make_loader())
  1949. for i in intermediates
  1950. ),
  1951. ranges,
  1952. [split],
  1953. # welford_reduce turns one input into three outputs, which are combined with welford_combine
  1954. "welford_combine",
  1955. reduction_hint,
  1956. )
  1957. @ir_dataclass
  1958. class Scan(Loops):
  1959. scan_ranges: list[Integer]
  1960. size: list[Integer]
  1961. combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]]
  1962. reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]]
  1963. reduction_hint: ReductionHint
  1964. output_index: int
  1965. # output_index indexes the following tuples
  1966. dtypes: tuple[torch.dtype, ...]
  1967. inner_fns: tuple[Callable[..., Any], ...]
  1968. # HACK we mimic reduction
  1969. @cache_on_self_and_args("Scan")
  1970. def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  1971. # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
  1972. # need to explicitly represent the closure so we can pull out unbacked
  1973. # symbols here
  1974. return (
  1975. super().get_free_symbol_uses(unbacked_only)
  1976. | OrderedSet().union(
  1977. *(get_free_symbols(e, unbacked_only) for e in self.scan_ranges)
  1978. )
  1979. | OrderedSet().union(
  1980. *(get_free_symbols(e, unbacked_only) for e in self.size)
  1981. )
  1982. )
  1983. def __post_init__(self) -> None:
  1984. assert len(self.ranges) + len(self.scan_ranges) == len(self.size)
  1985. super().__post_init__()
  1986. def store_reduction(
  1987. self,
  1988. output_name: Optional[str],
  1989. indexer: Callable[[Sequence[_IntLike]], Never],
  1990. vars: Sequence[Expr],
  1991. scan_vars: Sequence[Symbol],
  1992. ) -> Any:
  1993. idx = self.reindex(vars, scan_vars)
  1994. values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
  1995. result = ops.scan(self.dtypes, self.combine_fn, values)
  1996. return ops.store(
  1997. output_name or "unnamed", indexer(idx), result[self.output_index]
  1998. )
  1999. def get_reduction_type(self) -> Optional[str]:
  2000. # return self.scan_op
  2001. return "custom"
  2002. def get_reduction_size(self) -> Sequence[Expr]:
  2003. return self.scan_ranges
  2004. def get_size(self) -> Sequence[Expr]:
  2005. return self.size
  2006. def get_pointwise_size(self) -> Sequence[Expr]:
  2007. return self.ranges
  2008. def index_length(self) -> int:
  2009. return len(self.ranges) + len(self.scan_ranges)
  2010. def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]:
  2011. index = self._index(self.ranges)
  2012. rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
  2013. idx = self.reindex(index, rindex)
  2014. return (idx,)
  2015. def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2016. index = self._index(self.ranges)
  2017. rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
  2018. idx = self.reindex(index, rindex)
  2019. return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only)
  2020. @classmethod
  2021. def create( # type: ignore[override]
  2022. cls,
  2023. device: torch.device,
  2024. dtypes: tuple[torch.dtype, ...],
  2025. inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...],
  2026. size: list[Integer],
  2027. axis: int,
  2028. combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
  2029. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  2030. *,
  2031. # Whether we have the option to fallback to aten
  2032. can_fallback_to_aten: bool = True,
  2033. **kwargs: Any,
  2034. ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]:
  2035. pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
  2036. scan_ranges = [size[axis]]
  2037. if not V.graph.has_feature(device, BackendFeature.SCAN):
  2038. return [None] * len(dtypes)
  2039. if len(dtypes) > 1 and not V.graph.has_feature(
  2040. device, BackendFeature.TUPLE_REDUCTION
  2041. ):
  2042. return [None] * len(dtypes)
  2043. sizevars = V.graph.sizevars
  2044. scan_numel = sizevars.simplify(sympy_product(scan_ranges))
  2045. assert len(dtypes) == len(inner_fns)
  2046. # Scan with a single element is just a copy
  2047. if sizevars.statically_known_true(sympy.Le(scan_numel, 1)):
  2048. return [
  2049. Pointwise.create(
  2050. device=device,
  2051. dtype=dtypes[output_index],
  2052. inner_fn=inner_fns[output_index],
  2053. ranges=size,
  2054. )
  2055. for output_index in range(len(dtypes))
  2056. ]
  2057. reduction_hint, num_splits = cls.num_splits(
  2058. device=device,
  2059. dtype=dtypes[0],
  2060. inner_fn=inner_fns[0],
  2061. axis=axis,
  2062. pointwise_ranges=pointwise_ranges,
  2063. scan_ranges=scan_ranges,
  2064. combine_fn=combine_fn,
  2065. scan_numel=scan_numel,
  2066. )
  2067. scan_type = Scan
  2068. if num_splits > 1:
  2069. supports_split = (
  2070. torch.version.hip is None or (has_triton and triton_version >= "3.3.0")
  2071. ) and (len(dtypes) == 1)
  2072. if not supports_split:
  2073. if can_fallback_to_aten:
  2074. # Fallback to ATen
  2075. return [None] * len(dtypes)
  2076. else:
  2077. num_splits = 1
  2078. else:
  2079. scan_type = SplitScan
  2080. def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> list[Expr]:
  2081. assert len(scan_index) == len(scan_ranges)
  2082. assert len(index) == len(pointwise_ranges)
  2083. return [*index[:axis], *scan_index, *index[axis:]]
  2084. results = [
  2085. TensorBox.create(
  2086. scan_type(
  2087. device=device,
  2088. dtype=dtypes[output_index],
  2089. dtypes=dtypes,
  2090. inner_fn=inner_fns[output_index],
  2091. inner_fns=inner_fns,
  2092. size=size,
  2093. ranges=pointwise_ranges,
  2094. scan_ranges=scan_ranges,
  2095. combine_fn=combine_fn,
  2096. reindex=reindex,
  2097. reduction_hint=reduction_hint,
  2098. output_index=output_index,
  2099. **kwargs,
  2100. )
  2101. )
  2102. for output_index in range(len(dtypes))
  2103. ]
  2104. for result in results:
  2105. result.realize()
  2106. return results
  2107. @classmethod
  2108. def num_splits(
  2109. cls,
  2110. device: torch.device,
  2111. dtype: torch.dtype,
  2112. inner_fn: Callable[[Sequence[Expr]], OpsValue],
  2113. axis: int,
  2114. pointwise_ranges: list[Integer],
  2115. scan_ranges: list[Integer],
  2116. combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
  2117. scan_numel: Expr,
  2118. ) -> tuple[ReductionHint, _IntLike]:
  2119. # TODO: custom splitting heuristic for scan
  2120. def wrapper_fn(idx: Sequence[Expr], reduction_idx: Sequence[Expr]) -> OpsValue:
  2121. return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]])
  2122. return Reduction.num_splits(
  2123. device=device,
  2124. dst_dtype=dtype,
  2125. src_dtype=dtype,
  2126. inner_fn=wrapper_fn,
  2127. ranges=pointwise_ranges,
  2128. reduction_ranges=scan_ranges,
  2129. reduction_type="scan",
  2130. reduction_numel=scan_numel,
  2131. )
# This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA.
# Scan.create instantiates this class instead of Scan when the scan is split
# (num_splits > 1) and splitting is supported; it adds no fields or methods.
@ir_dataclass
class SplitScan(Scan):
    pass
  2136. @ir_dataclass
  2137. class Sort(Loops):
  2138. # Sorts a tuple of key, value pairs
  2139. sort_ranges: list[Integer]
  2140. size: list[Integer]
  2141. reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]]
  2142. reduction_hint: ReductionHint
  2143. output_index: int
  2144. # output_index indexes the following tuples
  2145. dtypes: tuple[torch.dtype, ...]
  2146. inner_fns: tuple[Callable[..., Any], ...]
  2147. stable: bool
  2148. descending: bool
  2149. # HACK we mimic reduction
  2150. @cache_on_self_and_args("Sort")
  2151. def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2152. return (
  2153. super().get_free_symbol_uses(unbacked_only)
  2154. | OrderedSet().union(
  2155. *(get_free_symbols(e, unbacked_only) for e in self.sort_ranges)
  2156. )
  2157. | OrderedSet().union(
  2158. *(get_free_symbols(e, unbacked_only) for e in self.size)
  2159. )
  2160. )
  2161. def __post_init__(self) -> None:
  2162. assert len(self.ranges) + len(self.sort_ranges) == len(self.size)
  2163. super().__post_init__()
  2164. def store_reduction(
  2165. self,
  2166. output_name: Optional[str],
  2167. indexer: Callable[[Sequence[Expr]], Expr],
  2168. vars: Sequence[Expr],
  2169. reduction_vars: Sequence[Expr],
  2170. ) -> Any:
  2171. idx = self.reindex(vars, reduction_vars)
  2172. values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
  2173. result = ops.sort(self.dtypes, values, self.stable, self.descending)
  2174. return ops.store(
  2175. output_name or "unnamed", indexer(idx), result[self.output_index]
  2176. )
  2177. def get_reduction_type(self) -> Optional[str]:
  2178. return "sort"
  2179. def get_reduction_size(self) -> Sequence[Expr]:
  2180. return self.sort_ranges
  2181. def get_size(self) -> Sequence[Expr]:
  2182. return self.size
  2183. def get_pointwise_size(self) -> Sequence[Expr]:
  2184. return self.ranges
  2185. def index_length(self) -> int:
  2186. return len(self.ranges) + len(self.sort_ranges)
  2187. def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
  2188. index = self._index(self.ranges)
  2189. rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
  2190. idx = self.reindex(index, rindex)
  2191. return (idx,)
  2192. def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2193. index = self._index(self.ranges)
  2194. rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
  2195. idx = self.reindex(index, rindex)
  2196. return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only)
  2197. @classmethod
  2198. def create( # type: ignore[override]
  2199. cls,
  2200. device: torch.device,
  2201. dtypes: tuple[torch.dtype, ...],
  2202. inner_fns: tuple[Callable[[list[Expr]], Any], ...],
  2203. size: list[Integer],
  2204. axis: int,
  2205. stable: bool,
  2206. descending: bool,
  2207. reduction_hint: ReductionHint = ReductionHint.DEFAULT,
  2208. **kwargs: Any,
  2209. ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]:
  2210. pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
  2211. sort_ranges = [size[axis]]
  2212. if not V.graph.has_feature(device, BackendFeature.SORT):
  2213. return [None] * len(dtypes)
  2214. sizevars = V.graph.sizevars
  2215. sort_numel = sizevars.simplify(sympy_product(sort_ranges))
  2216. # Heuristic, smallest rblock where triton usually outperforms aten.sort
  2217. # It also isn't bandwidth bound so fusion is unlikely to help.
  2218. max_rblock = 512
  2219. is_persistent_kernel = (
  2220. config.triton.persistent_reductions
  2221. and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock))
  2222. )
  2223. if not is_persistent_kernel:
  2224. # We only support persistent triton kernels
  2225. return [None] * len(dtypes)
  2226. assert len(dtypes) == len(inner_fns)
  2227. # Sort with a single element is just a copy
  2228. if sizevars.statically_known_true(sympy.Le(sort_numel, 1)):
  2229. return [
  2230. Pointwise.create(
  2231. device=device,
  2232. dtype=dtypes[output_index],
  2233. inner_fn=inner_fns[output_index],
  2234. ranges=size,
  2235. )
  2236. for output_index in range(len(dtypes))
  2237. ]
  2238. def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]:
  2239. assert len(sort_index) == len(sort_ranges)
  2240. assert len(index) == len(pointwise_ranges)
  2241. return [*index[:axis], *sort_index, *index[axis:]]
  2242. results = [
  2243. TensorBox.create(
  2244. Sort(
  2245. device=device,
  2246. dtype=dtypes[output_index],
  2247. dtypes=dtypes,
  2248. inner_fn=inner_fns[output_index],
  2249. inner_fns=inner_fns,
  2250. size=size,
  2251. ranges=pointwise_ranges,
  2252. sort_ranges=sort_ranges,
  2253. reindex=reindex,
  2254. reduction_hint=reduction_hint,
  2255. output_index=output_index,
  2256. stable=stable,
  2257. descending=descending,
  2258. **kwargs,
  2259. )
  2260. )
  2261. for output_index in range(len(dtypes))
  2262. ]
  2263. for result in results:
  2264. result.realize()
  2265. return results
  2266. def is_storage_and_layout(x: IRNode) -> bool:
  2267. try:
  2268. as_storage_and_layout(x, freeze=False)
  2269. return True
  2270. except NotImplementedError:
  2271. return False
  2272. def is_contiguous_storage_and_layout(x: IRNode) -> bool:
  2273. try:
  2274. _buffer, layout = as_storage_and_layout(x, freeze=False)
  2275. # pad the stride here so we will NOT claim an tensor as contiguous
  2276. # if a padding is gonna happen.
  2277. if layout.should_pad_strides():
  2278. layout.pad_strides()
  2279. return layout.is_contiguous()
  2280. except NotImplementedError:
  2281. return False
  2282. def as_storage_and_layout(
  2283. x: IRNode,
  2284. freeze: bool = True,
  2285. want_contiguous: bool = False,
  2286. stride_order: Optional[Sequence[Union[int, Integer]]] = None,
  2287. allow_padding: bool = False,
  2288. exact_strides: Optional[Sequence[Union[int, Integer]]] = None,
  2289. ) -> tuple[StorageBox, Layout]:
  2290. """
  2291. Try to simplify x into a StorageBox and a Layout.
  2292. allow_padding only affect how we apply stride_order. When allow_padding
  2293. is True, we have the freedom to add padding when applying the stride_order.
  2294. """
  2295. if isinstance(x, TensorBox):
  2296. return as_storage_and_layout(
  2297. x.data,
  2298. freeze=freeze,
  2299. want_contiguous=want_contiguous,
  2300. stride_order=stride_order,
  2301. allow_padding=allow_padding,
  2302. exact_strides=exact_strides,
  2303. )
  2304. if isinstance(x, StorageBox):
  2305. _, layout = as_storage_and_layout(
  2306. x.data,
  2307. freeze=freeze,
  2308. want_contiguous=want_contiguous,
  2309. stride_order=stride_order,
  2310. allow_padding=allow_padding,
  2311. exact_strides=exact_strides,
  2312. )
  2313. return x, x.data.get_layout()
  2314. if isinstance(x, Buffer):
  2315. if freeze:
  2316. if want_contiguous:
  2317. x.freeze_layout()
  2318. assert x.get_layout().is_contiguous()
  2319. elif stride_order is not None:
  2320. x.freeze_layout_with_stride_order(
  2321. stride_order, allow_padding=allow_padding
  2322. )
  2323. elif exact_strides is not None:
  2324. x.freeze_layout_with_exact_strides(
  2325. exact_strides, allow_padding=allow_padding
  2326. )
  2327. else:
  2328. x.decide_layout()
  2329. return StorageBox(x), x.get_layout()
  2330. if isinstance(x, ReinterpretView):
  2331. # making the base of x contiguous or stride_ordered will not necessarily make
  2332. # the ReinterpretView either, so don't pass along those arguments
  2333. buffer, _ = as_storage_and_layout(
  2334. x.data,
  2335. freeze=freeze,
  2336. )
  2337. return buffer, x.layout
  2338. raise NotImplementedError
  2339. def is_stride_order_storage_and_layout(
  2340. x: IRNode, stride_order: Sequence[Union[int, Integer]]
  2341. ) -> bool:
  2342. try:
  2343. _buffer, layout = as_storage_and_layout(x, freeze=False)
  2344. return layout.is_stride_ordered(stride_order)
  2345. except NotImplementedError:
  2346. return False
  2347. def is_unaligned(node: IRNode) -> bool:
  2348. if isinstance(node, (TensorBox, StorageBox)):
  2349. return is_unaligned(node.data)
  2350. if isinstance(node, ReinterpretView):
  2351. layout = node.layout
  2352. has_unaligned_layout = not V.graph.sizevars.statically_known_multiple_of(
  2353. layout.offset * get_dtype_size(layout.dtype), GPU_ALIGN_BYTES
  2354. )
  2355. return is_unaligned(node.data) or has_unaligned_layout
  2356. if isinstance(node, Buffer):
  2357. return node.get_name() in V.graph.unaligned_buffers
  2358. # assume to be aligned otherwise
  2359. return False
  2360. @ir_dataclass
  2361. class BaseView(IRNode):
  2362. data: IRNode
  2363. @cache_on_self_and_args("BaseView")
  2364. def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
  2365. return self.data.get_free_symbol_uses(unbacked_only)
  2366. def make_reindexer(self) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
  2367. raise NotImplementedError(f"make_reindexer NYI on {self}")
  2368. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  2369. inner = self.data.make_indexer()
  2370. reindex = self.make_reindexer()
  2371. def indexer(idx: Sequence[Expr]) -> Expr:
  2372. return inner(reindex(idx))
  2373. return indexer
  2374. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  2375. inner = self.data.make_loader()
  2376. reindex = self.make_reindexer()
  2377. def loader(idx: Sequence[Expr]) -> OpsValue:
  2378. return inner(reindex(idx))
  2379. return loader
  2380. @property
  2381. def dtype(self) -> torch.dtype:
  2382. return self.data.get_dtype()
  2383. def get_layout(self) -> Layout:
  2384. return self.data.get_layout()
  2385. def get_device(self) -> Optional[torch.device]:
  2386. return self.data.get_device()
  2387. def get_origin_node(self) -> Optional[torch.fx.Node]:
  2388. return None
  2389. def get_name(self) -> str:
  2390. return self.data.get_name()
  2391. def get_pointwise_size(self) -> Sequence[Expr]:
  2392. return self.get_size()
  2393. def mark_reuse(self, users: int) -> None:
  2394. return self.data.mark_reuse(users)
  2395. def has_exceeded_max_reads(self) -> bool:
  2396. return self.data.has_exceeded_max_reads()
  2397. def realize(self) -> Optional[str]:
  2398. return self.data.realize()
  2399. def realize_hint(self) -> None:
  2400. self.data.realize_hint()
  2401. def get_storage_numel(self) -> _IntLike:
  2402. return self.data.get_storage_numel()
  2403. def is_extern(self) -> bool:
  2404. return self.data.is_extern()
  2405. def is_module_buffer(self) -> bool:
  2406. assert isinstance(self.data, BaseView), type(self.data)
  2407. return self.data.is_module_buffer()
  2408. def get_read_names(self) -> OrderedSet[str]:
  2409. return self.data.get_read_names()
  2410. def get_reads(self) -> OrderedSet[Dep]:
  2411. with patch.object(FlexibleLayout, "allow_indexing", True):
  2412. return extract_read_writes(
  2413. self.make_loader(),
  2414. self.get_size(),
  2415. ).reads
  2416. def unwrap_view(self) -> IRNode:
  2417. x: IRNode = self
  2418. while isinstance(x, BaseView):
  2419. x = x.data
  2420. return x
  2421. def constant_to_device(self, device: torch.device) -> IRNode:
  2422. """Move this to a given device. Requires that all reads are to constants."""
  2423. loader = self.make_loader()
  2424. loader = patch.object(ConstantBuffer, "override_device", device)(loader)
  2425. return Pointwise(
  2426. device=device,
  2427. dtype=self.get_dtype(),
  2428. inner_fn=loader,
  2429. ranges=self.get_size(),
  2430. )
  2431. @ir_dataclass
  2432. class ExpandView(BaseView):
  2433. size: Sequence[Expr]
  2434. @staticmethod
  2435. def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLike]:
  2436. """Replace `-1` with correct sizes"""
  2437. sizevars = V.graph.sizevars
  2438. new_size = [sympy.expand(s) for s in new_size]
  2439. old_size = x.get_size()
  2440. old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
  2441. assert len(new_size) == len(old_size)
  2442. for i in range(len(new_size)):
  2443. if new_size[i] == -1:
  2444. assert old_size[i] is not None
  2445. new_size[i] = old_size[i]
  2446. elif old_size[i] is None or V.graph.sizevars.is_size_one_or_false(
  2447. old_size[i]
  2448. ):
  2449. pass
  2450. else:
  2451. # Sanity check: Expect broadcast compatibility
  2452. #
  2453. # NB: new_size[i] == old_size[i] is expected to already be
  2454. # guarded because the meta formula was expected to have taught
  2455. # us this equality.
  2456. assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
  2457. "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
  2458. )
  2459. return new_size
  2460. @classmethod
  2461. def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView:
  2462. new_size = cls._normalize_size(x, new_size)
  2463. if is_storage_and_layout(x):
  2464. storage, old_layout = as_storage_and_layout(x)
  2465. skip = len(new_size) - len(old_layout.size)
  2466. assert skip >= 0
  2467. new_stride = [sympy.S.Zero] * skip
  2468. for stride, size in zip(old_layout.stride, old_layout.size):
  2469. new_stride.append(
  2470. stride
  2471. if not V.graph.sizevars.is_size_one_or_false(size)
  2472. else sympy.S.Zero
  2473. )
  2474. new_layout = FixedLayout(
  2475. old_layout.device,
  2476. old_layout.dtype,
  2477. list(new_size),
  2478. new_stride,
  2479. old_layout.offset,
  2480. old_layout.is_pinned,
  2481. )
  2482. return ReinterpretView(data=storage, layout=new_layout)
  2483. return ExpandView(data=x, size=new_size)
  2484. def get_size(self) -> Sequence[Expr]:
  2485. return self.size
  2486. def make_reindexer(
  2487. self,
  2488. ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
  2489. target = self.get_size()
  2490. actual = self.data.get_size()
  2491. skip = len(target) - len(actual)
  2492. def reindex(
  2493. index: Sequence[Expr],
  2494. ) -> Sequence[Expr]:
  2495. index = list(index[skip:])
  2496. assert len(index) == len(actual)
  2497. for i in range(len(actual)):
  2498. if actual[i] == 1:
  2499. # zero out broadcast dimension
  2500. index[i] = sympy.S.Zero
  2501. return index
  2502. return reindex
  2503. @ir_dataclass
  2504. class PermuteView(BaseView):
  2505. dims: list[Expr]
  2506. @classmethod
  2507. def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView:
  2508. dims = cls._map_neg_dims(dims)
  2509. assert OrderedSet(dims) == OrderedSet(range(len(dims)))
  2510. if is_storage_and_layout(x):
  2511. storage, old_layout = as_storage_and_layout(x)
  2512. new_layout = FixedLayout(
  2513. old_layout.device,
  2514. old_layout.dtype,
  2515. [old_layout.size[i] for i in dims],
  2516. [old_layout.stride[i] for i in dims],
  2517. old_layout.offset,
  2518. old_layout.is_pinned,
  2519. )
  2520. return ReinterpretView(data=storage, layout=new_layout)
  2521. return PermuteView(data=x, dims=dims)
  2522. @classmethod
  2523. def _map_neg_dims(cls, dims: Sequence[int]) -> list[int]:
  2524. return [dim if dim >= 0 else len(dims) + dim for dim in dims]
  2525. def get_size(self) -> Sequence[Expr]:
  2526. assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet(
  2527. range(len(self.dims))
  2528. )
  2529. size = self.data.get_size()
  2530. return [size[i] for i in self.dims]
  2531. def make_reindexer(
  2532. self,
  2533. ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
  2534. inv = {j: i for i, j in enumerate(self.dims)}
  2535. inv = [inv[i] for i in range(len(self.dims))]
  2536. assert OrderedSet(inv) == OrderedSet(range(len(self.dims)))
  2537. def reindex(
  2538. index: Sequence[Expr],
  2539. ) -> Sequence[Expr]:
  2540. return [index[i] for i in inv]
  2541. return reindex
  2542. @ir_dataclass
  2543. class SqueezeView(BaseView):
  2544. @classmethod
  2545. def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode:
  2546. if is_storage_and_layout(x):
  2547. storage, old_layout = as_storage_and_layout(x)
  2548. new_size = []
  2549. new_stride = []
  2550. if dim is not None:
  2551. assert isinstance(dim, int), type(dim)
  2552. assert 0 <= dim and dim < len(old_layout.size)
  2553. for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
  2554. if dim is None:
  2555. if size != 1:
  2556. new_size.append(size)
  2557. new_stride.append(stride)
  2558. else:
  2559. if i != dim:
  2560. new_size.append(size)
  2561. new_stride.append(stride)
  2562. else:
  2563. assert size == 1, "expected squeezed size to be 1"
  2564. new_layout = FixedLayout(
  2565. old_layout.device,
  2566. old_layout.dtype,
  2567. new_size,
  2568. new_stride,
  2569. old_layout.offset,
  2570. old_layout.is_pinned,
  2571. )
  2572. return ReinterpretView(data=storage, layout=new_layout)
  2573. if dim is None:
  2574. # redirect to a generic view
  2575. return View.create(x, [s for s in x.get_size() if s != 1])
  2576. else:
  2577. assert x.get_size()[dim] == 1
  2578. return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])
  2579. @staticmethod
  2580. def squeezer(
  2581. size: Sequence[Expr],
  2582. ) -> tuple[list[int], Callable[[Sequence[Expr]], tuple[Expr]]]:
  2583. new_size = [s for s in size if s != 1]
  2584. not_one = [i for i, s in enumerate(size) if s != 1]
  2585. length = len(size)
  2586. def reindex(index: Sequence[Expr]) -> tuple[Expr]:
  2587. assert len(index) == len(not_one), f"{index} {not_one}"
  2588. new_index = [sympy.S.Zero] * length
  2589. for idx, s in zip(not_one, index):
  2590. new_index[idx] = s
  2591. return tuple(new_index)
  2592. return new_size, reindex
  2593. def __init__(self, data: Any) -> None:
  2594. raise AssertionError("use SqueezeView.create()")
  2595. @ir_dataclass
  2596. class GenericView(BaseView):
  2597. size: Sequence[Expr]
  2598. reindex: Callable[[Sequence[Expr]], Sequence[Expr]]
  2599. def make_reindexer(
  2600. self,
  2601. ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
  2602. return self.reindex
  2603. def reindex_str(self) -> str:
  2604. index_old = [
  2605. sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size))
  2606. ]
  2607. index_new = list(self.reindex(index_old))
  2608. return f"lambda {', '.join(map(str, index_old))}: {index_new}"
  2609. def __str__(self) -> str:
  2610. return self.str_helper(
  2611. [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
  2612. )
  2613. __repr__ = __str__
  2614. @classmethod
  2615. def create(
  2616. cls,
  2617. x: IRNode,
  2618. new_size: Sequence[Expr],
  2619. reindex: Callable[[Sequence[Expr]], Sequence[Expr]],
  2620. ) -> BaseView:
  2621. return cls(data=x, size=list(new_size), reindex=reindex)
  2622. def get_size(self) -> Sequence[Expr]:
  2623. return self.size
@ir_dataclass
class View(GenericView):
    # Reshape-style view: the factory methods below build either a
    # ReinterpretView over contiguous storage or a View whose reindex function
    # performs the reshape purely with index arithmetic.

    @staticmethod
    def handle_negative_index(idx: Expr, size: Expr) -> Expr:
        """Return ``idx + size`` when the shape env evaluates ``idx < 0`` as true, else ``idx``.

        Both arguments are passed through ``sympy.expand`` first.
        """
        idx = sympy.expand(idx)
        size = sympy.expand(size)
        evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr
        if evaluate_expr(sympy.Lt(idx, 0)):
            idx = idx + size
        return idx

    @classmethod
    def create(cls, x: IRNode, new_size: Sequence[Expr]) -> IRNode:  # type: ignore[override]
        """Create a view of ``x`` with size ``new_size``.

        Returns ``x`` itself when the sizes are statically known to be equal, a
        view with an all-zero reindex when ``new_size`` contains 0, a
        ``ReinterpretView`` with contiguous strides when ``x`` is contiguous
        storage (or unbacked symbols appear in either size), and otherwise a
        view driven by ``dynamic_reshape_indexer``.
        """
        assert isinstance(new_size, Sequence), type(new_size)
        old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)

        # Skip pointless views
        if V.graph.sizevars.statically_known_list_equals(old_size, new_size):
            return x

        unbacked_symbols_in_sizes = False
        if (
            len(free_unbacked_symbols(old_size)) > 0
            or len(free_unbacked_symbols(new_size)) > 0
        ):
            unbacked_symbols_in_sizes = True

        if 0 in new_size:
            # The reindex maps every index to all zeros of the old rank.

            def fake_reindex(index: Any) -> tuple[int, ...]:
                return tuple([0] * len(old_size))

            return cls(data=x, size=list(new_size), reindex=fake_reindex)
        # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
        elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes:
            if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)):
                # realize x; otherwise, the dynamic_reshape_indexer below will fail
                # due to the size_hint's inability to process unbacked SymInts
                # TODO: unbacked should not diverge from backed in determining striding
                # Need to require contiguous here instead of realize, see:
                # https://github.com/pytorch/pytorch/issues/145561
                x = ExternKernel.require_contiguous(x)

            storage, old_layout = as_storage_and_layout(x, want_contiguous=True)
            new_layout = FixedLayout(
                old_layout.device,
                old_layout.dtype,
                new_size,
                FlexibleLayout.contiguous_strides(new_size),
                old_layout.offset,
                old_layout.is_pinned,
            )
            return ReinterpretView(data=storage, layout=new_layout)

        reindex = cls.dynamic_reshape_indexer(old_size, new_size)
        return cls(data=x, size=list(new_size), reindex=reindex)

    @staticmethod
    def resolve_negative_size(
        old_size: Sequence[Expr], new_size: Sequence[Expr]
    ) -> tuple[list[Expr], list[Expr]]:
        """Simplify both sizes and replace the first ``-1`` in ``new_size``.

        The ``-1`` entry becomes ``CleanDiv(prod(old_size), prod(other new
        sizes))``. Only the first ``-1`` is handled (the loop breaks after it).
        Finally the two element-count products are checked for equality via
        ``check_equals``.
        """
        new_size = [V.graph.sizevars.simplify(x) for x in new_size]
        old_size = [V.graph.sizevars.simplify(x) for x in old_size]

        new_size = list(new_size)
        for i in range(len(new_size)):
            if new_size[i] == -1:
                # Temporarily set the slot to 1 so it drops out of the product.
                new_size[i] = sympy.S.One
                new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
                break

        V.graph.sizevars.check_equals(sympy_product(old_size), sympy_product(new_size))
        return old_size, new_size

    @classmethod
    def dynamic_reshape_indexer(
        cls,
        old_size: Sequence[_IntLike],
        new_size: Sequence[_IntLike],
        dense_dim: Optional[int] = None,
    ) -> Callable[[Sequence[_T]], Sequence[_V]]:
        """Return a function mapping indices of ``new_size`` to indices of ``old_size``.

        First tries ``_dynamic_reshape_indexer`` directly. If that raises
        ``AssertionError`` or ``IndexError``, falls back to reshaping through a
        flat one-dimensional size and fusing the two reindexers; note that
        ``dense_dim`` is not passed along in the fallback.
        """
        try:
            reindex = cls._dynamic_reshape_indexer(old_size, new_size, dense_dim)
        except (AssertionError, IndexError):
            # optimistic algorithm failed, lets do a fallback
            flat = [sympy_product(old_size)]
            reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
            reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
            reindex = fuse_reindexing(reindex1, reindex2)
        return reindex

    @staticmethod
    def _dynamic_reshape_indexer(
        old_size: Sequence[Expr],
        new_size: Sequence[Expr],
        dense_dim: Optional[int] = None,
    ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
        """
        Perform a reshape entirely by modifying indexing math

        Dimensions of both sizes are consumed from the back. Each old dimension
        receives one expression, written in terms of one temporary symbol per
        new dimension; the returned function substitutes the caller's indices
        for those symbols. Comparisons between sizes use size hints and are
        followed by ``check_equals`` on the symbolic sizes.
        """
        size_hint = V.graph.sizevars.size_hint
        # TODO: These symbols may not escape, if they don't assert so and
        # treat them as temporary
        vars = [
            sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size))
        ]

        stack_new = list(zip(vars, new_size))
        stack_old = list(old_size)

        # process the dense dim first
        reordering_dense_dim = (
            dense_dim is not None
            and dense_dim != len(stack_old) - 1
            and len(new_size) == 1
        )
        if reordering_dense_dim:
            assert dense_dim is not None  # mypy
            # Move the dense dim to the end so it is popped first.
            old_dim = stack_old.pop(dense_dim)
            stack_old.append(old_dim)

        # Expressions are collected back-to-front and reversed at the end.
        view_expr = []
        while stack_new and stack_old:
            size_old = stack_old.pop()
            var, size_new = stack_new.pop()
            if size_old == 1:
                # Size-one old dim is always indexed with 0; retry the new dim.
                view_expr.append(sympy.S.Zero)
                stack_new.append((var, size_new))  # re-add
            elif size_new == 1:
                # Size-one new dim contributes nothing; retry the old dim.
                stack_old.append(size_old)  # re-add
            elif size_hint(size_new) == size_hint(size_old):
                # One-to-one dimension.
                view_expr.append(var)
                V.graph.sizevars.check_equals(size_new, size_old)
            elif size_hint(size_new) < size_hint(size_old):
                # Several new dims merge into one old dim: build a combined index.
                while size_hint(size_new) < size_hint(size_old):
                    var2, size_new2 = stack_new.pop()
                    var = var2 * size_new + var
                    size_new = size_new * size_new2
                view_expr.append(var)
                V.graph.sizevars.check_equals(size_new, size_old)
            elif size_hint(size_new) > size_hint(size_old):
                # One new dim splits across several old dims via ModularIndexing.
                divisor = sympy.S.One
                modulus = size_old
                view_expr.append(ModularIndexing(var, divisor, modulus))
                divisor = divisor * modulus
                while size_hint(size_new) > size_hint(size_old):
                    modulus = stack_old.pop()
                    view_expr.append(ModularIndexing(var, divisor, modulus))
                    divisor = divisor * modulus
                    size_old = size_old * modulus
                V.graph.sizevars.check_equals(size_new, size_old)
            else:
                raise AssertionError

        # Leftover old dims must have size 1 and are indexed with 0.
        while stack_old:
            size_old = stack_old.pop()
            V.graph.sizevars.check_equals(size_old, 1)
            view_expr.append(sympy.S.Zero)

        # Leftover new dims must have size 1 and are ignored.
        while stack_new:
            var, size_new = stack_new.pop()
            V.graph.sizevars.check_equals(size_new, 1)

        if dense_dim is not None and len(new_size) == 1:
            view_expr.reverse()
            # Move the last expression (dense dim) to its original position
            dense_expr = view_expr.pop()
            view_expr.insert(dense_dim, dense_expr)
        else:
            view_expr.reverse()

        assert len(view_expr) == len(old_size)

        def reindex(
            index: Sequence[Expr],
        ) -> Sequence[Expr]:
            # Substitute the caller's indices for the temporary view symbols.
            assert len(index) == len(vars), (len(index), len(vars))
            replacements = dict(zip(vars, index))
            return tuple(sympy_subs(x, replacements) for x in view_expr)

        return reindex
  2783. @ir_dataclass
  2784. class ReinterpretView(BaseView):
  2785. """Pretend our storage has a different layout"""
  2786. layout: Layout
  2787. def __post_init__(self) -> None:
  2788. super().__post_init__()
  2789. if isinstance(self.data, BaseView):
  2790. object.__setattr__(self, "data", self.data.unwrap_view())
  2791. def __str__(self) -> str:
  2792. return self.str_helper(
  2793. [
  2794. self.data,
  2795. self.layout,
  2796. ]
  2797. )
  2798. __repr__ = __str__
  2799. def get_name(self) -> str:
  2800. return self.data.get_name()
  2801. def get_device(self) -> Optional[torch.device]:
  2802. return self.layout.device
  2803. def get_origin_node(self) -> Optional[torch.fx.Node]:
  2804. return None
  2805. @property
  2806. def dtype(self) -> torch.dtype:
  2807. return self.layout.dtype
  2808. def get_size(self) -> Sequence[Expr]:
  2809. return list(self.layout.size)
  2810. def get_stride(self) -> Sequence[Expr]:
  2811. return list(self.layout.stride)
  2812. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  2813. def loader(index: Sequence[Expr]) -> OpsValue:
  2814. indexer = self.layout.make_indexer()
  2815. tmp_loader = ops.load(self.get_name(), indexer(index))
  2816. if self.layout.dtype != self.data.dtype:
  2817. return ops.to_dtype_bitcast(tmp_loader, self.dtype, self.data.dtype)
  2818. else:
  2819. return tmp_loader
  2820. return loader
  2821. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  2822. return self.layout.make_indexer()
  2823. def get_layout(self) -> Layout:
  2824. return self.layout
  2825. def freeze_layout(self) -> None:
  2826. pass
  2827. @cache_on_self_and_args("ReinterpretView")
  2828. def get_free_symbol_uses(
  2829. self, unbacked_only: bool = False
  2830. ) -> OrderedSet[sympy.Symbol]:
  2831. return (
  2832. get_free_symbols(self.layout.size, unbacked_only)
  2833. | get_free_symbols(self.layout.stride, unbacked_only)
  2834. | get_free_symbols(self.layout.offset, unbacked_only)
  2835. )
  2836. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  2837. # reinterpret_tensor is similar to as_strided except:
  2838. # - offset is added to the existing offset (rather than replacing it)
  2839. # - view tracking is disabled similar to unsafe_view
  2840. return V.graph.wrapper_code.codegen_reinterpret_view(
  2841. self.data,
  2842. self.layout.size,
  2843. self.layout.stride,
  2844. self.layout.offset,
  2845. writer.writeline if writer is not None else V.graph.wrapper_code.writeline,
  2846. dtype=self.layout.dtype,
  2847. )
  2848. def num_reads(self) -> int:
  2849. return 1
  2850. @ir_dataclass
  2851. class DtypeView(BaseView):
  2852. """Pretend our storage has a different type"""
  2853. target_dtype: torch.dtype
  2854. @classmethod
  2855. def create(cls, x: IRNode, new_dtype: torch.dtype) -> BaseView:
  2856. if is_storage_and_layout(x):
  2857. storage, old_layout = as_storage_and_layout(x)
  2858. new_layout = FixedLayout(
  2859. old_layout.device,
  2860. new_dtype,
  2861. old_layout.size,
  2862. old_layout.stride,
  2863. old_layout.offset,
  2864. old_layout.is_pinned,
  2865. )
  2866. return ReinterpretView(data=storage, layout=new_layout)
  2867. return DtypeView(data=x, target_dtype=new_dtype)
  2868. def __str__(self) -> str:
  2869. return self.str_helper([self.data, self.target_dtype])
  2870. __repr__ = __str__
  2871. @property
  2872. def dtype(self) -> torch.dtype:
  2873. return self.target_dtype
  2874. def get_size(self) -> Sequence[Expr]:
  2875. return self.data.get_size()
  2876. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  2877. inner = self.data.make_loader()
  2878. def loader(idx: Sequence[Expr]) -> OpsValue:
  2879. return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype)
  2880. return loader
  2881. class SliceView(View):
  2882. @classmethod
  2883. def normalize_start_end(
  2884. cls, x: IRNode, dim: int, start: int, end: int
  2885. ) -> tuple[int, int]:
  2886. """
  2887. Normalize start and end such that both are in the range
  2888. [0, x.get_size()[dim]] and start <= end.
  2889. """
  2890. sizevars = V.graph.sizevars
  2891. dim_size = x.get_size()[dim]
  2892. if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
  2893. min_func = sympy.Min
  2894. max_func = sympy.Max
  2895. else:
  2896. min_func = sizevars.evaluate_min
  2897. max_func = sizevars.evaluate_max
  2898. def clamp(x: Expr, lower: int, upper: int) -> Expr:
  2899. clamped_lower = (
  2900. x if sizevars.statically_known_geq(x, lower) else max_func(x, lower)
  2901. )
  2902. clamped_full = (
  2903. clamped_lower
  2904. if sizevars.statically_known_leq(clamped_lower, upper)
  2905. else min_func(clamped_lower, upper)
  2906. )
  2907. return clamped_full
  2908. def clamp_wrap(
  2909. val: Union[int, None], lower: int, upper: int, default: Union[Expr, int]
  2910. ) -> Union[Expr, int]:
  2911. if val is None:
  2912. # TODO(rec): can this really happen?
  2913. return default
  2914. val = cls.handle_negative_index(val, dim_size)
  2915. return clamp(val, lower, upper)
  2916. start = clamp_wrap(start, 0, dim_size, 0)
  2917. end = clamp_wrap(end, start, dim_size, dim_size)
  2918. return start, end
  2919. @classmethod
  2920. def create( # type: ignore[override]
  2921. cls,
  2922. x: IRNode,
  2923. dim: int,
  2924. start: int,
  2925. end: int,
  2926. step: int = 1,
  2927. clamp: bool = True,
  2928. ) -> IRNode:
  2929. step = sympy.expand(step)
  2930. assert isinstance(step, Expr) or step > 0, step
  2931. try:
  2932. if start == 0 and end >= 2**63 - 1 and step == 1:
  2933. return x
  2934. except TypeError:
  2935. pass
  2936. new_size = list(x.get_size())
  2937. # NB: Ordinarily we default to clamping.
  2938. # We only don't clamp for split_with_sizes. For split_with_sizes, sizes should be already valid
  2939. # failing in this situation is ok, since invalid sizes could trigger silent errors.
  2940. if clamp:
  2941. start, end = cls.normalize_start_end(x, dim, start, end)
  2942. new_size[dim] = FloorDiv(end - start + (step - 1), step)
  2943. if is_storage_and_layout(x):
  2944. # Fast path
  2945. storage, old_layout = as_storage_and_layout(x)
  2946. new_stride = list(old_layout.stride)
  2947. new_stride[dim] = new_stride[dim] * step
  2948. new_layout = FixedLayout(
  2949. old_layout.device,
  2950. old_layout.dtype,
  2951. new_size,
  2952. new_stride,
  2953. old_layout.offset + old_layout.stride[dim] * start,
  2954. old_layout.is_pinned,
  2955. )
  2956. return ReinterpretView(data=storage, layout=new_layout)
  2957. def reindex(
  2958. index: Sequence[Expr],
  2959. ) -> Sequence[Expr]:
  2960. assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
  2961. index = list(index)
  2962. index[dim] = index[dim] * step + start
  2963. return index
  2964. # redirect to a generic view
  2965. return SliceView(data=x, size=new_size, reindex=reindex)
  2966. @ir_dataclass
  2967. class BaseConstant(IRNode):
  2968. dtype: torch.dtype
  2969. device: torch.device
  2970. def get_size(self) -> Sequence[Expr]:
  2971. return ()
  2972. def get_device(self) -> Optional[torch.device]:
  2973. return self.device
  2974. def get_origin_node(self) -> Optional[torch.fx.Node]:
  2975. return None
  2976. def get_reads(self) -> OrderedSet[Dep]:
  2977. return OrderedSet()
  2978. @ir_dataclass
  2979. class Constant(BaseConstant):
  2980. value: Any
  2981. dtype: torch.dtype
  2982. device: torch.device
  2983. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  2984. def loader(index: Sequence[Expr]) -> OpsValue:
  2985. return ops.constant(self.value, self.dtype)
  2986. return loader
  2987. def realize(self) -> Optional[str]:
  2988. pass
  2989. def constant_to_device(self, device: torch.device) -> IRNode:
  2990. return Constant(value=self.value, dtype=self.dtype, device=device)
  2991. @ir_dataclass
  2992. class IndexingConstant(BaseConstant):
  2993. index: Any
  2994. dtype: torch.dtype
  2995. device: torch.device
  2996. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  2997. def loader(index: Sequence[Expr]) -> OpsValue:
  2998. return ops.index_expr(self.index, self.dtype)
  2999. return loader
  3000. def constant_to_device(self, device: torch.device) -> IRNode:
  3001. return IndexingConstant(index=self.index, dtype=self.dtype, device=device)
  3002. def is_contiguous_strides_for_shape(
  3003. stride: Sequence[_IntLike], shape: Sequence[_IntLike]
  3004. ) -> bool:
  3005. expected_stride = 1
  3006. expected_stride_max = 1
  3007. for x, y in reversed(tuple(zip(shape, stride))):
  3008. if x == 1:
  3009. continue
  3010. if not V.graph.sizevars.statically_known_equals(
  3011. y, expected_stride
  3012. ) and not V.graph.sizevars.statically_known_equals(y, expected_stride_max):
  3013. return False
  3014. expected_stride_max *= sympy.Max(1, x)
  3015. expected_stride *= x
  3016. return True
  3017. def get_align_for_dtype(dtype: torch.dtype) -> int:
  3018. return config.padding_alignment_bytes // dtype.itemsize
  3019. class OutputSpec:
  3020. """Abstract base for Layout, MultiOutputLayout, NoneLayout.
  3021. Represents the memory layout of the output of an Operation."""
  3022. def get_device(self) -> Optional[torch.device]:
  3023. raise NotImplementedError(type(self).__name__)
  3024. def storage_size(self) -> int:
  3025. raise NotImplementedError(type(self).__name__)
  3026. def get_free_symbol_uses(
  3027. self, unbacked_only: bool = False
  3028. ) -> OrderedSet[sympy.Symbol]:
  3029. raise NotImplementedError(type(self).__name__)
  3030. @ir_dataclass
  3031. class Layout(OutputSpec):
  3032. """
  3033. Layout base class
  3034. Carries tensor meta-information including offset and
  3035. whether it is pinned.
  3036. """
  3037. def __init__(
  3038. self,
  3039. device: torch.device,
  3040. dtype: torch.dtype,
  3041. size: Sequence[Expr],
  3042. stride: Optional[Sequence[Expr]] = None,
  3043. offset: Expr = Integer(0),
  3044. is_pinned: bool = False,
  3045. ) -> None:
  3046. if stride is None:
  3047. stride = FlexibleLayout.contiguous_strides(size)
  3048. self.device = device
  3049. self.dtype = dtype
  3050. assert len(size) == len(stride), f"size={size}, stride={stride}"
  3051. assert all(isinstance(s, (Expr, int)) for s in size)
  3052. self._size = size
  3053. self._stride = stride
  3054. self._offset = offset
  3055. self.is_pinned = is_pinned
  3056. # is_pinned implies cpu
  3057. assert (not self.is_pinned) or (self.device.type == "cpu")
  3058. @property
  3059. def size(self) -> Sequence[Expr]:
  3060. return self._size
  3061. @size.setter
  3062. def size(self, value: Sequence[Expr]) -> None:
  3063. self._size = value
  3064. @property
  3065. def stride(self) -> Sequence[Expr]:
  3066. return self._stride
  3067. @stride.setter
  3068. def stride(self, value: Sequence[Expr]) -> None:
  3069. self._stride = value
  3070. @property
  3071. def offset(self) -> Expr:
  3072. return self._offset
  3073. @offset.setter
  3074. def offset(self, value: Expr) -> None:
  3075. self._offset = value
  3076. def __str__(self) -> str:
  3077. offset = ""
  3078. if self.offset != 0:
  3079. offset = f", offset={self.offset}"
  3080. device_index_str = "" if self.device.index is None else f":{self.device.index}"
  3081. is_pinned_str = ""
  3082. if self.is_pinned:
  3083. is_pinned_str = f", is_pinned={self.is_pinned}"
  3084. return (
  3085. f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, "
  3086. f"size={self.size}, stride={self.stride}{offset}{is_pinned_str})"
  3087. )
  3088. __repr__ = __str__
  3089. def get_device(self) -> torch.device:
  3090. return self.device
  3091. def get_example(self) -> torch.Tensor:
  3092. with V.fake_mode:
  3093. return torch.empty_strided(
  3094. convert_shape_to_symint(self.size),
  3095. convert_shape_to_symint(self.stride),
  3096. dtype=self.dtype,
  3097. device=self.device,
  3098. pin_memory=self.is_pinned,
  3099. )
  3100. def is_contiguous(self) -> bool:
  3101. return is_contiguous_strides_for_shape(self.stride, self.size)
  3102. @staticmethod
  3103. def is_channels_last_contiguous(
  3104. shape: Sequence[_IntLike], strides: Sequence[_IntLike]
  3105. ) -> bool:
  3106. ndim = len(shape)
  3107. if ndim not in [4, 5] or shape[1] == 1:
  3108. return False
  3109. for left, right, size in zip(
  3110. strides, make_channels_last_strides_for(shape), shape
  3111. ):
  3112. if size != 1 and left != right:
  3113. return False
  3114. return True
  3115. def is_transposed(self) -> bool:
  3116. for left, right, size in zip(
  3117. self.stride,
  3118. reversed(FlexibleLayout.contiguous_strides(list(reversed(self.size)))),
  3119. self.size,
  3120. ):
  3121. if size != 1 and left != right:
  3122. return False
  3123. return True
  3124. def is_stride_ordered(self, order: Sequence[int]) -> bool:
  3125. assert len(self.stride) == len(order)
  3126. # ignore dimensions of size 1, they dont affect layout
  3127. non_1_indices = [
  3128. i
  3129. for i, dim in enumerate(self.size)
  3130. if V.graph.sizevars.size_hint(dim, fallback=2) != 1
  3131. ]
  3132. stride = [self.stride[i] for i in non_1_indices]
  3133. order: Sequence[int] = [order[i] for i in non_1_indices]
  3134. def sorted_indices(arr: Sequence[int]) -> Sequence[int]:
  3135. sorted_arr = sorted(arr)
  3136. return [sorted_arr.index(element) for element in arr]
  3137. # since we may have removed dimensions, need to re-sort & re-index order
  3138. order = sorted_indices(order)
  3139. # reorder the stride given order
  3140. stride_ordered = [-1] * len(order)
  3141. for i in range(len(order)):
  3142. stride_ordered[order[i]] = stride[i]
  3143. # check if it is in ascending order
  3144. for i in range(len(order) - 1):
  3145. expr = stride_ordered[i] > stride_ordered[i + 1]
  3146. if not isinstance(expr, bool):
  3147. expr = V.graph._shape_env.evaluate_expr(
  3148. stride_ordered[i] > stride_ordered[i + 1], size_oblivious=True
  3149. )
  3150. if expr:
  3151. return False
  3152. return True
  3153. def is_channels_last_stride_ordered(self) -> bool:
  3154. # create channels_last order(NCHW, NCDHW, the C is the first order).
  3155. order = [0] + list(reversed(range(1, len(self.stride) - 1)))
  3156. order = [len(order)] + order
  3157. return self.is_stride_ordered(order)
  3158. @staticmethod
  3159. def _pad_strides(
  3160. in_strides: Sequence[int], size: Sequence[Expr], dtype: torch.dtype
  3161. ) -> Sequence[int]:
  3162. """
  3163. The padding does not change stride order but makes sure all strides larger
  3164. than the threshold are multiple of align.
  3165. """
  3166. align = get_align_for_dtype(dtype)
  3167. if len(in_strides) == 0:
  3168. return in_strides
  3169. if not config.pad_channels_last and Layout.is_channels_last_contiguous(
  3170. size, in_strides
  3171. ):
  3172. return in_strides
  3173. current_fx_node = V.get_current_node()
  3174. if hasattr(current_fx_node, "meta") and current_fx_node.meta.get(
  3175. "dislike_padding", False
  3176. ):
  3177. return in_strides
  3178. shape_env = V.graph._shape_env if hasattr(V.graph, "_shape_env") else None
  3179. def contains_unbacked_symints(expr: sympy.Expr | int) -> bool:
  3180. if shape_env is None:
  3181. return False
  3182. if not isinstance(expr, sympy.Expr):
  3183. return False
  3184. return any(shape_env.is_unbacked_symint(s) for s in expr.free_symbols)
  3185. # Skip padding the strides when it contains unbacked symints for now.
  3186. if shape_env and any(contains_unbacked_symints(s) for s in in_strides):
  3187. return in_strides
  3188. stride_order = get_stride_order(in_strides, shape_env)
  3189. fill_order = stride_order2fill_order(stride_order)
  3190. new_strides = [0 for _ in range(len(in_strides))]
  3191. # since we pad when the layout is flexible, we can decide the
  3192. # smallest stride to be 1.
  3193. new_strides[fill_order[0]] = 1
  3194. padded = False
  3195. for rank, idx in enumerate(fill_order[1:], start=1):
  3196. prev_idx = fill_order[rank - 1]
  3197. stride = new_strides[prev_idx] * size[prev_idx]
  3198. # Static stride and meets padding conditions OR
  3199. # Dynamic stride and config.pad_dynamic_shape=True
  3200. require_padding = (
  3201. isinstance(stride, (int, sympy.Integer))
  3202. and stride > config.padding_stride_threshold
  3203. and stride % align != 0
  3204. ) or (isinstance(stride, sympy.Expr) and config.pad_dynamic_shapes)
  3205. new_strides[idx] = stride
  3206. if require_padding:
  3207. new_strides[idx] = ceildiv(stride, align) * align
  3208. padded = True
  3209. if not padded:
  3210. # Consider a tensor with shape [256, 1, 5, 5]
  3211. # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides
  3212. # [25, 25, 5, 1].
  3213. return in_strides
  3214. metrics.num_comprehensive_padding += 1
  3215. return new_strides
  3216. def pad_strides(self) -> None:
  3217. assert isinstance(self, FlexibleLayout), type(self)
  3218. assert self.stride is not None
  3219. self.stride = self._pad_strides(self.stride, self.size, self.dtype)
  3220. def should_pad_strides(self) -> bool:
  3221. return config.comprehensive_padding and isinstance(self, FlexibleLayout)
  3222. def as_fixed(self) -> FixedLayout:
  3223. if isinstance(self, FixedLayout):
  3224. return self
  3225. if self.should_pad_strides():
  3226. self.pad_strides()
  3227. return FixedLayout(
  3228. self.device,
  3229. self.dtype,
  3230. self.size,
  3231. self.stride,
  3232. self.offset,
  3233. self.is_pinned,
  3234. )
  3235. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3236. assert FlexibleLayout.allow_indexing, (
  3237. f"convert {type(self).__name__} to FixedLayout first"
  3238. )
  3239. return self.as_fixed().make_indexer()
  3240. def __eq__(self, other: object) -> bool:
  3241. return (
  3242. isinstance(other, Layout)
  3243. and self.device == other.device
  3244. and self.dtype == other.dtype
  3245. and self.size == other.size
  3246. and self.stride == other.stride
  3247. and self.offset == other.offset
  3248. and self.is_pinned == other.is_pinned
  3249. )
  3250. def storage_size(self) -> Expr:
  3251. return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
  3252. @cache_on_self_and_args("Layout")
  3253. def get_free_symbol_uses(
  3254. self, unbacked_only: bool = False
  3255. ) -> OrderedSet[sympy.Symbol]:
  3256. return (
  3257. get_free_symbols(self.size, unbacked_only)
  3258. | get_free_symbols(self.stride, unbacked_only)
  3259. | get_free_symbols(self.offset, unbacked_only)
  3260. )
  3261. class FixedLayout(Layout):
  3262. """A Tensor layout we cannot change"""
  3263. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3264. """A closure containing math to read a given element"""
  3265. return _fixed_indexer(self.size, self.stride, self.offset)
  3266. class FlexibleLayout(Layout):
  3267. """
  3268. A Tensor layout that we are allowed to change
  3269. Assumption: layout change should NOT add or remove free symbols
  3270. """
  3271. allow_indexing = False
  3272. # WARNING! This doesn't handle zero size tensors correctly
  3273. @staticmethod
  3274. def contiguous_strides(sizes: Sequence[int]) -> list[Expr]:
  3275. if len(sizes) == 0:
  3276. return []
  3277. reversed_strides = [sympy.S.One]
  3278. for size in reversed(sizes[1:]):
  3279. reversed_strides.append(size * reversed_strides[-1])
  3280. return list(reversed(reversed_strides))
  3281. @staticmethod
  3282. def fill_ordered(sizes: Sequence[int], order: Sequence[int]) -> list[Expr]:
  3283. """
  3284. Create a stride based on the order the dimensions should be filled in.
  3285. In this format, channels last would be:
  3286. [1, 3, 2, 0]
  3287. """
  3288. assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order)
  3289. next_stride = sympy.S.One
  3290. strides = [None] * len(order)
  3291. for i in order:
  3292. strides[i] = next_stride
  3293. next_stride = next_stride * sizes[i]
  3294. return strides
  3295. @staticmethod
  3296. def stride_ordered(sizes: Sequence[int], order: Sequence[int]) -> Sequence[Expr]:
  3297. """
  3298. Create a stride based on the sorted order of a permuted range.
  3299. In this format, channels last would be:
  3300. [3, 0, 2, 1]
  3301. """
  3302. assert OrderedSet(range(len(sizes))) == OrderedSet(order)
  3303. fill_order = stride_order2fill_order(order)
  3304. return FlexibleLayout.fill_ordered(sizes, fill_order)
  3305. @staticmethod
  3306. def stride_ordered_for_memory_format(
  3307. sizes: Sequence[int], memory_format: torch.memory_format
  3308. ) -> Sequence[Expr]:
  3309. """
  3310. Create a stride based on a memory format.
  3311. Memory format is translasted into a stride order,
  3312. so channels_last is the same as:
  3313. FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1])
  3314. This interface does not support memory_format `torch.preserve_format`
  3315. which should be used to deduce a format from another source
  3316. """
  3317. if memory_format == torch.channels_last:
  3318. return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER)
  3319. elif memory_format == torch.channels_last_3d:
  3320. return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER)
  3321. elif memory_format == torch.contiguous_format:
  3322. return FlexibleLayout.contiguous_strides(sizes)
  3323. else:
  3324. log.debug(
  3325. "stride_ordered_for_memory_format, unsuppored memory_format: %s",
  3326. memory_format,
  3327. )
  3328. raise NotImplementedError
  3329. @staticmethod
  3330. def same_ordered(
  3331. sizes: Sequence[int], stride: Sequence[_IntLike]
  3332. ) -> Sequence[Expr]:
  3333. """
  3334. Create a stride that has the same stride order as given stride
  3335. For example, if given stride is [1000, 1, 100, 10],
  3336. the fill order should be [1, 3, 2, 0]
  3337. """
  3338. assert len(sizes) == len(stride)
  3339. stride = [V.graph.sizevars.size_hint_or_throw(x) for x in stride]
  3340. fill_order = sorted(range(len(stride)), key=stride.__getitem__)
  3341. return FlexibleLayout.fill_ordered(sizes, fill_order)
  3342. @property
  3343. def size(self) -> Sequence[Expr]:
  3344. return self._size
  3345. @size.setter
  3346. def size(self, value: Sequence[Expr]) -> None:
  3347. self.assert_free_symbol_uses_unchanged("size", value)
  3348. self._size = value
  3349. @property
  3350. def stride(self) -> Sequence[Expr]:
  3351. return self._stride
  3352. @stride.setter
  3353. def stride(self, value: Sequence[Expr]) -> None:
  3354. self.assert_free_symbol_uses_unchanged("stride", value)
  3355. self._stride = value
  3356. @property
  3357. def offset(self) -> Expr:
  3358. return self._offset
  3359. @offset.setter
  3360. def offset(self, value: Expr) -> None:
  3361. self.assert_free_symbol_uses_unchanged("offset", value)
  3362. self._offset = value
  3363. def as_stride_order(
  3364. self, order: Sequence[int], allow_padding: bool = False
  3365. ) -> FixedLayout:
  3366. new_stride = self.stride_ordered(self.size, order)
  3367. if self.should_pad_strides() and allow_padding:
  3368. new_stride = self._pad_strides(new_stride, self.size, self.dtype)
  3369. return FixedLayout(
  3370. self.device,
  3371. self.dtype,
  3372. self.size,
  3373. new_stride,
  3374. self.offset,
  3375. self.is_pinned,
  3376. )
  3377. def as_exact_strides(
  3378. self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
  3379. ) -> FixedLayout:
  3380. new_stride = exact_strides
  3381. if self.should_pad_strides() and allow_padding:
  3382. new_stride = self._pad_strides(new_stride, self.size, self.dtype)
  3383. return FixedLayout(
  3384. self.device,
  3385. self.dtype,
  3386. self.size,
  3387. new_stride,
  3388. self.offset,
  3389. self.is_pinned,
  3390. )
  3391. def as_fill_order(self, order: Sequence[int]) -> FixedLayout:
  3392. new_stride: Sequence[int] = self.fill_ordered(self.size, order)
  3393. if self.should_pad_strides():
  3394. new_stride = self._pad_strides(new_stride, self.size, self.dtype)
  3395. return FixedLayout(
  3396. self.device,
  3397. self.dtype,
  3398. self.size,
  3399. new_stride,
  3400. self.offset,
  3401. self.is_pinned,
  3402. )
  3403. def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout:
  3404. new_stride = self.same_ordered(self.size, stride)
  3405. if self.should_pad_strides():
  3406. new_stride = self._pad_strides(new_stride, self.size, self.dtype)
  3407. return FixedLayout(
  3408. self.device,
  3409. self.dtype,
  3410. self.size,
  3411. new_stride,
  3412. self.offset,
  3413. self.is_pinned,
  3414. )
  3415. def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], sympy.Symbol]:
  3416. initial_free_symbols = {}
  3417. for name in ["size", "stride", "offset"]:
  3418. for unbacked_only in [True, False]:
  3419. key = (name, unbacked_only)
  3420. initial_free_symbols[key] = OrderedSet(
  3421. get_free_symbols(getattr(self, name), unbacked_only)
  3422. )
  3423. return initial_free_symbols
  3424. def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
  3425. for unbacked_only in [True, False]:
  3426. old_free_symbols = self.initial_free_symbols[(name, unbacked_only)]
  3427. new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
  3428. assert new_free_symbols == old_free_symbols, (
  3429. f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
  3430. )
  3431. def __init__(
  3432. self,
  3433. device: torch.device,
  3434. dtype: torch.dtype,
  3435. size: Sequence[Expr],
  3436. stride_order: Optional[Sequence[Union[int, Integer]]] = None,
  3437. is_pinned: bool = False,
  3438. ) -> None:
  3439. if stride_order:
  3440. strides = FlexibleLayout.fill_ordered(size, stride_order)
  3441. else:
  3442. strides = FlexibleLayout.contiguous_strides(size)
  3443. super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
  3444. # record the initial free symbols to check that we do not add new free symbols
  3445. # later when modifying sizes, strides, and offsets.
  3446. self.initial_free_symbols = self.get_initial_free_symbol_uses()
  3447. class NonOwningLayout(Layout):
  3448. """Is a view into the storage of another tensor"""
  3449. def __init__(self, view: Union[BaseView, TensorBox]) -> None:
  3450. layout = view.get_layout()
  3451. super().__init__(
  3452. layout.device,
  3453. layout.dtype,
  3454. layout.size,
  3455. layout.stride,
  3456. )
  3457. self.view = view
  3458. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3459. return self.as_fixed().make_indexer()
  3460. def maybe_guard_aligned(self) -> bool:
  3461. offset = self.view.get_layout().offset
  3462. if offset == 0:
  3463. return True
  3464. from .utils import ALIGNMENT
  3465. return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
  3466. @cache_on_self_and_args("NonOwningLayout")
  3467. def get_free_symbol_uses(
  3468. self, unbacked_only: bool = False
  3469. ) -> OrderedSet[sympy.Symbol]:
  3470. assert isinstance(self.view, ReinterpretView)
  3471. box = self.view.data
  3472. assert isinstance(box, StorageBox), type(box)
  3473. input_buffer = box.data
  3474. assert isinstance(input_buffer, Buffer), type(box)
  3475. return input_buffer.layout.get_free_symbol_uses(unbacked_only)
  3476. class CommBufferType(Enum):
  3477. SYMM_MEM = "symm_mem"
  3478. class CommBufferLayout(FixedLayout):
  3479. """
  3480. A layout that signifies the buffer is a comm buffer.
  3481. In terms of striding, the layout is identical to `FixedLayout`.
  3482. Buffers with this layout do not participate in in-place reuse - it can be
  3483. neither the source nor the target for in-place reuse.
  3484. For detailed motivation and usage of this layout, see
  3485. NOTE [lowering-time collective optimization].
  3486. """
  3487. comm_buffer_type: CommBufferType
  3488. group_name: str
  3489. def __init__(
  3490. self,
  3491. layout: FlexibleLayout,
  3492. comm_buffer_type: CommBufferType,
  3493. group_name: str,
  3494. ):
  3495. if not isinstance(layout, FlexibleLayout):
  3496. raise AssertionError(
  3497. "A `CommBufferLayout` can only be initialized with "
  3498. f"a `FlexibleLayout` (got {layout})."
  3499. )
  3500. fixed = layout.as_fixed()
  3501. super().__init__(
  3502. device=fixed.device,
  3503. dtype=fixed.dtype,
  3504. size=fixed.size,
  3505. stride=fixed.stride,
  3506. offset=fixed.offset,
  3507. is_pinned=fixed.is_pinned,
  3508. )
  3509. self.comm_buffer_type = comm_buffer_type
  3510. self.group_name = group_name
  3511. @ir_dataclass
  3512. class NoneLayout(OutputSpec):
  3513. # This is janky, I figured out what fields to populate by just running
  3514. # the model I was interested in and adding properties/methods as needed.
  3515. # This doesn't inherit from Layout because Layout assumes you have stuff
  3516. # like sizes, but I don't really have anything here.
  3517. #
  3518. # If you have an ir.Node with NoneLayout, you probably need to setup
  3519. # dependencies manually in scheduler
  3520. device: Optional[torch.device]
  3521. size: list[int] = dataclasses.field(default_factory=lambda: [0])
  3522. stride: list[int] = dataclasses.field(default_factory=lambda: [0])
  3523. def storage_size(self) -> int:
  3524. return 0
  3525. def as_fixed(self) -> OutputSpec:
  3526. return self
  3527. def get_device(self) -> Optional[torch.device]:
  3528. return self.device
  3529. class MutationLayoutSHOULDREMOVE(Layout):
  3530. def __init__(self, target: IRNode) -> None:
  3531. super().__init__(
  3532. target.get_device_or_error(),
  3533. target.get_dtype(),
  3534. target.get_size(),
  3535. None,
  3536. )
  3537. self.target = target
  3538. name = self.get_buffer().get_name()
  3539. V.graph.mark_buffer_mutated(name)
  3540. @property
  3541. def stride(self) -> Sequence[Expr]: # type: ignore[override]
  3542. return self.real_layout().stride
  3543. @stride.setter # type: ignore[override]
  3544. def stride(self, value: Never) -> None:
  3545. pass # ignore setting of stride
  3546. def storage_size(self) -> Expr:
  3547. return self.real_layout().storage_size()
  3548. def get_buffer(self) -> Buffer:
  3549. def unwrap_views(target: Any) -> Any:
  3550. if isinstance(target, MutationLayoutSHOULDREMOVE):
  3551. return unwrap_views(target.target)
  3552. if isinstance(target, BaseView):
  3553. return unwrap_views(target.unwrap_view())
  3554. if isinstance(target, MutableBox):
  3555. return unwrap_views(target.data)
  3556. return target
  3557. result = unwrap_views(self.target)
  3558. assert isinstance(result, Buffer), type(result)
  3559. return result
  3560. def real_layout(self) -> Layout:
  3561. layout = self.get_buffer().layout
  3562. assert isinstance(layout, Layout)
  3563. return layout
  3564. @classmethod
  3565. def realize_into(
  3566. cls, src: IRNode, dst: IRNode, unsafe_alias: bool = False
  3567. ) -> IRNode:
  3568. dst.realize()
  3569. # NOTE: We must realize users of `dst` before we realize `src`, since
  3570. # realization order determines scheduling order. Otherwise, src's
  3571. # mutation would be scheduled before the existing users of dst!
  3572. V.graph.mark_buffer_mutated(dst.get_name())
  3573. if isinstance(src, TensorBox):
  3574. src = src.data
  3575. # We copy the contents of src into dst. In most cases this should
  3576. # be fused into a single kernel by the scheduler.
  3577. # NOTE: We cannot change src's layout to mutate dst directly as this
  3578. # would alias src to dst, which is not correct as further mutations to
  3579. # dst would effect users of src. However if there are no more users of
  3580. # dst, we can alias src to dst.
  3581. src.realize_hint()
  3582. if not unsafe_alias:
  3583. node = Pointwise.create(
  3584. device=src.get_device(),
  3585. dtype=src.get_dtype(),
  3586. inner_fn=src.make_loader(),
  3587. ranges=[
  3588. V.graph.sizevars.check_equals_and_simplify(a, b)
  3589. for a, b in zip(src.get_size(), dst.get_size())
  3590. ],
  3591. )
  3592. assert isinstance(node, (BaseView, MutableBox))
  3593. src = node.data
  3594. src.realize()
  3595. assert hasattr(src, "data"), src
  3596. assert isinstance(src.data.layout, FlexibleLayout), type(src.data.layout)
  3597. src.data.layout = MutationLayoutSHOULDREMOVE(dst)
  3598. return src.data
  3599. def as_fixed(self) -> Self: # type: ignore[override]
  3600. return self
  3601. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3602. return self.target.make_indexer()
  3603. @ir_dataclass(frozen=False)
  3604. class Buffer(IRNode, CodegenSymbol):
  3605. # Name is sometimes None; e.g., ForceInPlace, where there isn't
  3606. # a meaningful name
  3607. name: Optional[str]
  3608. layout: OutputSpec
  3609. # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly,
  3610. # MultiOutput does NOT define this!
  3611. def __post_init__(self) -> None:
  3612. super().__post_init__()
  3613. self._post_init_setattr("origin_node", None)
  3614. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  3615. return self.get_layout().make_indexer()
  3616. def get_name(self) -> str:
  3617. assert self.name, self
  3618. return self.name
  3619. def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
  3620. if isinstance(self.layout, Layout):
  3621. return self.layout.get_example()
  3622. raise NotImplementedError(type(self.layout).__name__)
  3623. def get_device(self) -> Optional[torch.device]:
  3624. return self.get_output_spec().get_device()
  3625. def get_defining_op(self) -> Optional[Operation]:
  3626. return None
  3627. @property
  3628. def dtype(self) -> torch.dtype:
  3629. return self.get_layout().dtype
  3630. def get_size(self) -> Sequence[Expr]:
  3631. return [*self.get_layout().size]
  3632. def get_stride(self) -> list[Expr]:
  3633. return [*self.get_layout().stride]
  3634. def get_offset(self) -> Expr:
  3635. return self.get_layout().offset
  3636. def get_layout(self) -> Layout:
  3637. if isinstance(self.layout, Layout):
  3638. return self.layout
  3639. raise NotImplementedError(type(self.layout).__name__)
  3640. def get_output_spec(self) -> OutputSpec:
  3641. return self.layout
  3642. def get_storage_numel(self) -> int:
  3643. return self.get_numel()
  3644. def get_is_pinned(self) -> bool:
  3645. return self.get_layout().is_pinned
  3646. def freeze_layout(self) -> None:
  3647. if isinstance(self.layout, Layout) and not isinstance(
  3648. self.layout, NonOwningLayout
  3649. ):
  3650. self.layout = self.layout.as_fixed()
  3651. def freeze_layout_with_stride_order(
  3652. self, order: Sequence[int], allow_padding: bool = False
  3653. ) -> None:
  3654. assert isinstance(self.layout, FlexibleLayout), type(self.layout)
  3655. self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding)
  3656. def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
  3657. assert isinstance(self.layout, FlexibleLayout), type(self.layout)
  3658. self.layout = self.layout.as_fill_order(order)
  3659. def freeze_layout_with_same_order(self, stride: Sequence[int]) -> None:
  3660. assert isinstance(self.layout, FlexibleLayout), type(self.layout)
  3661. self.layout = self.layout.as_same_order(stride)
  3662. def freeze_layout_with_exact_strides(
  3663. self, exact_strides: Sequence[int], allow_padding: bool = False
  3664. ) -> None:
  3665. assert isinstance(self.layout, FlexibleLayout), type(self.layout)
  3666. self.layout = self.layout.as_exact_strides(
  3667. exact_strides, allow_padding=allow_padding
  3668. )
  3669. def is_zero_elements(self) -> bool:
  3670. return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
  3671. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  3672. # Loading from a zero-element buffer is a no-op
  3673. if self.is_zero_elements():
  3674. return partial(nop_loader_fn, dtype=self.get_dtype())
  3675. def loader(index: Sequence[Expr]) -> OpsValue:
  3676. indexer = self.make_indexer()
  3677. return ops.load(self.name or "unnamed", indexer(index))
  3678. return loader
  3679. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  3680. return self.get_name()
  3681. def decide_layout(self) -> None:
  3682. pass
  3683. def get_inputs_that_alias_output(self) -> Sequence[str]:
  3684. if isinstance(self.layout, NonOwningLayout):
  3685. return [self.layout.view.get_name()]
  3686. return ()
  3687. def get_mutation_names(self) -> Sequence[str]:
  3688. if isinstance(self.layout, MutationLayoutSHOULDREMOVE):
  3689. return [self.layout.target.get_name()]
  3690. return ()
  3691. def get_read_names(self) -> OrderedSet[str]:
  3692. return OrderedSet([self.get_name()])
  3693. @cache_on_self_and_args("Buffer")
  3694. def get_free_symbol_uses(
  3695. self, unbacked_only: bool = False
  3696. ) -> OrderedSet[sympy.Symbol]:
  3697. return OrderedSet()
  3698. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  3699. return OrderedSet()
  3700. def realize(self) -> Optional[str]:
  3701. pass
  3702. def should_allocate(self) -> bool:
  3703. # Returns False by default.
  3704. return False
  3705. @ir_dataclass(frozen=False)
  3706. class OperationBuffer(Buffer, Operation):
  3707. # An operation that produces a single output buffer
  3708. def get_outputs(self) -> list[Buffer]:
  3709. return [self]
  3710. def get_defining_op(self) -> Operation:
  3711. return self
  3712. # Skip implementation in Buffer
  3713. get_operation_name = Operation.get_operation_name
  3714. def __post_init__(self) -> None:
  3715. Buffer.__post_init__(self)
  3716. Operation.__post_init__(self)
  3717. class InputBuffer(Buffer):
  3718. def num_reads(self) -> int:
  3719. return 1
class DonatedBuffer(InputBuffer):
    """
    Represents a donated buffer, i.e. a saved tensor that does not alias any
    fwd inputs, fwd user outputs, or bwd outputs. We generally cannot reuse an
    input tensor's memory in place during backward since it might be used in
    another function. A donated buffer, however, can be reused in place during
    backward to save memory.
    """
  3728. class ConstantBuffer(InputBuffer):
  3729. override_device: Optional[torch.device] = None
  3730. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  3731. def loader(index: Sequence[Expr]) -> OpsValue:
  3732. indexer = self.get_layout().make_indexer()
  3733. return ops.load(
  3734. V.graph.constant_name(self.get_name(), self.override_device),
  3735. indexer(index),
  3736. )
  3737. return loader
  3738. def constant_to_device(self, device: torch.device) -> IRNode:
  3739. return ConstantBuffer(
  3740. name=V.graph.constant_name(self.get_name(), device), layout=self.layout
  3741. )
  3742. @ir_dataclass
  3743. class NoneAsConstantBuffer(IRNode):
  3744. def get_reads(self) -> OrderedSet[Dep]:
  3745. return OrderedSet()
  3746. @cache_on_self_and_args("NoneAsConstantBuffer")
  3747. def get_free_symbol_uses(
  3748. self, unbacked_only: bool = False
  3749. ) -> OrderedSet[sympy.Symbol]:
  3750. return OrderedSet()
  3751. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  3752. return V.graph.wrapper_code.none_str
  3753. def get_output_spec(self) -> OutputSpec:
  3754. return NoneLayout(device=None)
  3755. def has_tensor_output(self) -> bool:
  3756. return False
  3757. @ir_dataclass
  3758. class ShapeAsConstantBuffer(IRNode):
  3759. expr: Expr
  3760. @cache_on_self_and_args("ShapeAsConstantBuffer")
  3761. def get_free_symbol_uses(
  3762. self, unbacked_only: bool = False
  3763. ) -> OrderedSet[sympy.Symbol]:
  3764. return get_free_symbols(self.expr, unbacked_only)
  3765. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  3766. return V.graph.wrapper_code.codegen_sizevar(self.expr)
  3767. def has_tensor_output(self) -> bool:
  3768. return False
  3769. @ir_dataclass(frozen=False)
  3770. class ComputedBuffer(OperationBuffer):
  3771. """
  3772. Represents a buffer that is computed during kernel execution rather than being an input.
  3773. """
  3774. data: Loops
  3775. _force_realize: ClassVar[bool] = False
  3776. @staticmethod
  3777. @contextlib.contextmanager
  3778. def force_realize() -> Iterator[None]:
  3779. old_value = ComputedBuffer._force_realize
  3780. try:
  3781. ComputedBuffer._force_realize = True
  3782. yield
  3783. finally:
  3784. ComputedBuffer._force_realize = old_value
  3785. def get_computed_buffer_name(self) -> Optional[str]:
  3786. """
  3787. Returns self.name if it exists, otherwise returns the name of the data node if that exists.
  3788. If neither exist, returns None.
  3789. """
  3790. if self.name is not None:
  3791. return self.name
  3792. if hasattr(self.data, "name"):
  3793. return self.data.name
  3794. return None
  3795. def num_reads(self) -> int:
  3796. return self.data.num_reads()
  3797. def get_reads(self) -> OrderedSet[Dep]:
  3798. return self.data.get_reads()
  3799. def get_read_names(self) -> OrderedSet[str]:
  3800. return self.data.get_read_names()
  3801. def get_read_writes(self) -> dependencies.ReadWrites:
  3802. if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)):
  3803. return dependencies.ReadWrites(
  3804. reads=OrderedSet(),
  3805. writes=OrderedSet(),
  3806. index_exprs=OrderedSet(),
  3807. )
  3808. with patch.object(FlexibleLayout, "allow_indexing", True):
  3809. if self.data.get_reduction_type():
  3810. return extract_read_writes(
  3811. self.get_store_function(),
  3812. self.data.get_pointwise_size(),
  3813. self.data.get_reduction_size(),
  3814. )
  3815. else:
  3816. return extract_read_writes(
  3817. self.get_store_function(),
  3818. self.data.get_size(),
  3819. )
  3820. @cache_on_self_and_args("ComputedBuffer")
  3821. def get_free_symbol_uses(
  3822. self, unbacked_only: bool = False
  3823. ) -> OrderedSet[sympy.Symbol]:
  3824. # Ordinarily, we'd like to just peek at the arguments list,
  3825. # but ComputedBuffers have no argument list.
  3826. #
  3827. # Morally, this logic needs to be synchronized with the
  3828. # KernelArgs.size calls, which are responsible for making symbols make
  3829. # there way as kernel arguments (and it is precisely passing in one of
  3830. # those symbols that establishes a dependency). However, we haven't
  3831. # started codegen yet so we can't directly reuse that logic.
  3832. #
  3833. # One thing you might wonder is if this is enough for a ComputedBuffer
  3834. # denoting a reduction over i0. Empirically, it is enough, but for an
  3835. # unusual reason: we only need accurate dependencies for item() call,
  3836. # but it's impossible to end up with a reduction over i0 from an
  3837. # item() call without a regular non-reduction buffer first.
  3838. result = self.layout.get_free_symbol_uses(
  3839. unbacked_only
  3840. ) | self.data.get_free_symbol_uses(unbacked_only)
  3841. if self.has_store_function() and isinstance(
  3842. self.get_store_function(), LoopBody
  3843. ):
  3844. result |= self.get_read_writes().get_free_symbol_uses(unbacked_only)
  3845. return result
  3846. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  3847. if (
  3848. not self.get_reduction_type()
  3849. and self.name not in V.graph.mutated_buffers
  3850. and self.num_reads() == 0
  3851. and not self._force_realize
  3852. ):
  3853. # inline this op rather than generating ops.load()
  3854. return self.data.make_loader()
  3855. return super().make_loader()
  3856. def has_store_function(self) -> bool:
  3857. return isinstance(self.data, (Reduction, Scan, Sort, Pointwise))
  3858. def get_store_function(self) -> Callable[..., None]:
  3859. indexer = self.get_layout().as_fixed().make_indexer()
  3860. if isinstance(self.data, (Reduction, Scan, Sort)):
  3861. return partial(self.data.store_reduction, self.name, indexer)
  3862. else:
  3863. assert isinstance(self.data, Pointwise), type(self.data)
  3864. return partial(self.data.store_output, self.name, indexer)
  3865. def get_fill_order(self) -> Optional[list[int]]:
  3866. """
  3867. If our layout is still flexible, try to determine the stride order based on stride orders of reads.
  3868. TODO(jansel): A better algorithm here would look at downstream consumers of this
  3869. value and try to do global graph-level layout optimization.
  3870. This is also something just begging to be autotuned.
  3871. """
  3872. if isinstance(self.layout, FlexibleLayout):
  3873. (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
  3874. self.data.get_pointwise_size(), self.data.get_reduction_size()
  3875. )
  3876. reads = self.get_read_writes().reads
  3877. # only consider reads to buffer of same size
  3878. # ignore StarDeps because they don't contribute stride information
  3879. assert all(
  3880. isinstance(r, (dependencies.StarDep, dependencies.MemoryDep))
  3881. for r in reads
  3882. )
  3883. reads = [
  3884. sympy_subs(r.index, {v: sympy.S.Zero for v in reduction_vars if v != 0})
  3885. for r in reads
  3886. if isinstance(r, dependencies.MemoryDep)
  3887. ]
  3888. if reads:
  3889. if isinstance(self.data, (Scan, Sort)):
  3890. indices = self.data.reindex(index_vars, reduction_vars)
  3891. else:
  3892. indices = index_vars
  3893. stride_lengths = [
  3894. V.graph.sizevars.stride_hints(expr, indices) for expr in reads
  3895. ]
  3896. from .scheduler import pick_loop_order
  3897. return pick_loop_order(stride_lengths, self.get_size())
  3898. return None
  3899. def decide_layout(self) -> None:
  3900. if isinstance(self.layout, FlexibleLayout):
  3901. order = self.get_fill_order()
  3902. if order:
  3903. self.freeze_layout_with_fill_order(order)
  3904. else:
  3905. self.freeze_layout()
  3906. @cache_on_self
  3907. def get_default_sizes_body(
  3908. self,
  3909. ) -> tuple[
  3910. tuple[list[Expr], list[Expr]],
  3911. LoopBody,
  3912. tuple[list[Expr], list[Expr]],
  3913. ]:
  3914. args, var_ranges = dependencies.index_vars_squeeze(
  3915. self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
  3916. )
  3917. with patch.object(ConstantBuffer, "override_device", self.get_device()):
  3918. body = LoopBody(
  3919. self.get_store_function(),
  3920. (args if self.get_reduction_type() else args[:1]),
  3921. var_ranges,
  3922. *args,
  3923. )
  3924. index_vars = []
  3925. reduce_vars: list[Any] = []
  3926. index_size = []
  3927. reduce_size = []
  3928. for v, s in var_ranges.items():
  3929. if v in args[0]:
  3930. assert not reduce_vars
  3931. index_vars.append(v)
  3932. index_size.append(s)
  3933. else:
  3934. assert v in args[1]
  3935. reduce_vars.append(v)
  3936. reduce_size.append(s)
  3937. return (index_size, reduce_size), body, (index_vars, reduce_vars)
    def simplify_and_reorder(
        self,
        extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
        recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
    ) -> tuple[tuple[list[Expr], list[Expr]], Optional[LoopBody]]:
        """
        This is a main place where we do loop transformations in a
        backend-agnostic way.

        Here we:
            1) Remove any 1 dimensions
            2) Fuse contiguous dimensions together
            3) Reorder dimensions based on stride orders

        Optional argument extra_indexing_constraints can be used to append additional
        indexing expressions to existing ones derived from buffer's body. This can be useful
        to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...)
        on CPU by preventing indexing simplifications and obtaining index/reduce ranges for
        the scheduler node compatible with other nodes.
        Optional argument recompute_sizes_body_func can be used to recompute sizes and body
        on the default body. This can be useful to append additional loop transformations.

        Args:
            extra_indexing_constraints: optional ``(var_ranges, exprs)`` pair. The
                ranges must equal ``body.var_ranges``; expressions not already
                present in the body's indexing expressions are appended to them.
            recompute_sizes_body_func: optional callable that receives
                ``((index_size, reduce_size), body, (index_vars, reduce_vars))``
                and returns a replacement triple of the same shape.

        Returns:
            ``((iter_ranges, reduce_ranges), body)`` where ``body`` is a new
            ``LoopBody`` retraced over the reordered (and possibly simplified)
            ranges.
        """
        # Start from the untransformed sizes, loop body and loop variables.
        (
            (index_size, reduce_size),
            body,
            (index_vars, reduce_vars),
        ) = self.get_default_sizes_body()

        # Let the caller rewrite the default triple before any reordering.
        if recompute_sizes_body_func:
            (
                (index_size, reduce_size),
                body,
                (index_vars, reduce_vars),
            ) = recompute_sizes_body_func(
                (index_size, reduce_size), body, (index_vars, reduce_vars)
            )

        # These formulas are handed to index_prevent_reordering() below.
        index_formulas = [*body.indexing_exprs.values()]
        if extra_indexing_constraints is not None:
            assert (
                isinstance(extra_indexing_constraints, tuple)
                and len(extra_indexing_constraints) == 2
            )
            extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints
            assert isinstance(extra_indexing_ranges, dict), type(extra_indexing_ranges)
            assert isinstance(extra_indexing_expr, list), type(extra_indexing_expr)
            assert all(isinstance(f, Expr) for f in extra_indexing_expr)

            # The extra expressions are only accepted when they were written
            # against exactly the same variable ranges as this body.
            expected_var_ranges = body.var_ranges
            assert expected_var_ranges == extra_indexing_ranges, (
                expected_var_ranges,
                extra_indexing_ranges,
            )
            # remove already existing expressions
            extra_indexing_expr = [
                e for e in extra_indexing_expr if e not in index_formulas
            ]
            index_formulas += extra_indexing_expr

        # Addresses used to choose the loop order: always the writes, plus the
        # reads unless the backend reports PREFER_STORE_LOOP_ORDER.
        memory_addrs = [*body.get_write_exprs()]
        if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER):
            memory_addrs.extend(body.get_read_exprs())

        def simplify_and_reorder(
            x_vars: Sequence[sympy.Symbol],
            support_vars: Sequence[sympy.Symbol],
            sizes: Sequence[int],
            simplify_loops: bool,
        ) -> tuple[
            list[int],
            Callable[[Sequence[int]], Sequence[int]],
            Callable[[Sequence[int]], Sequence[int]],
        ]:
            # Local helper (shadows the method name inside this scope): reorder
            # one group of loops and, when simplify_loops is set, also run
            # _simplify_loops on it. Returns the new sizes, the reindexer to
            # apply to the new variables, and the reorder-only reindexer.
            sizes, reindex0, reindex1 = self._apply_loop_reordering(
                x_vars, support_vars, sizes, memory_addrs
            )
            # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
            x_vars = reindex0(x_vars)

            if simplify_loops:
                sizes, reindex2, _prune = V.graph.sizevars._simplify_loops(
                    x_vars,
                    sizes,
                    index_prevent_reordering(index_formulas, x_vars, sizes),
                )
                reindex = fuse_reindexing(reindex1, reindex2)
            else:
                reindex = reindex1
            return sizes, reindex, reindex1

        support_vars = index_vars + reduce_vars
        # Loops are merged unless we are on a GPU device with
        # config.loop_ordering_after_fusion enabled.
        should_merge_loops = (
            not is_gpu(get_device_type(self)) or not config.loop_ordering_after_fusion
        )
        iter_ranges, iter_reindex, _ = simplify_and_reorder(
            index_vars,
            support_vars,
            index_size,
            should_merge_loops,
        )

        # Like iteration dimensions, we may also want to delay merging reduction dimensions.
        # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise
        # kernel, merging M and N dimension too early makes it hard to decide what loop order
        # we should pick for the pointwise kernel so that it is fusible with the reduction.
        reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
            reduce_vars, support_vars, reduce_size, should_merge_loops
        )

        # retrace the loop body with simplification and reordering applied
        (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
            iter_ranges,
            reduce_ranges,
            prefix="p",
        )
        body = LoopBody(
            body,
            [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
            var_ranges,
            iter_vars,
            reduce_vars,
        )
        return (iter_ranges, reduce_ranges), body
  4050. @staticmethod
  4051. def _apply_loop_reordering(
  4052. index_vars: Sequence[sympy.Symbol],
  4053. support_vars: Sequence[sympy.Symbol],
  4054. sizes: Sequence[int],
  4055. memory_addrs: list[sympy.Expr],
  4056. priority_idx: Optional[list[int]] = None,
  4057. ) -> tuple[
  4058. list[int],
  4059. Callable[[Sequence[int]], Sequence[int]],
  4060. Callable[[Sequence[int]], Sequence[int]],
  4061. ]:
  4062. """
  4063. Shuffle the order of loops around to hopefully improve performance.
  4064. """
  4065. from .scheduler import pick_loop_order
  4066. if priority_idx is None:
  4067. priority_idx = []
  4068. try:
  4069. strides = [
  4070. V.graph.sizevars.stride_hints(expr, index_vars, support_vars)
  4071. for expr in memory_addrs
  4072. ]
  4073. assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
  4074. index_vars
  4075. )
  4076. order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
  4077. except Exception:
  4078. if config.debug:
  4079. log.warning(
  4080. "Did not simplify complex index:\n%s\n%s",
  4081. dict(zip(index_vars, sizes)),
  4082. memory_addrs,
  4083. )
  4084. order = list(range(len(sizes)))
  4085. sizes = [sizes[i] for i in order]
  4086. return sizes, same_reorder(order), inverse_reorder(order)
  4087. def get_reduction_size(self) -> Sequence[Expr]:
  4088. return self.data.get_reduction_size()
  4089. def get_reduction_type(self) -> Optional[str]:
  4090. return self.data.get_reduction_type()
  4091. def is_no_op(self) -> bool:
  4092. return self.data.is_zero_elements()
  4093. def should_allocate(self) -> bool:
  4094. return True
  4095. def constant_to_device(self, device: torch.device) -> IRNode:
  4096. """Move this to a given device. Requires that all reads are to constants."""
  4097. return self.data.constant_to_device(device)
  4098. class TemplateBuffer(OperationBuffer):
  4099. """
  4100. Represents a Triton (in the future other type) of template operator
  4101. that we can fuse an epilogue onto.
  4102. """
  4103. def __init__(
  4104. self,
  4105. layout: OutputSpec,
  4106. inputs: Sequence[IRNode],
  4107. make_kernel_render: Optional[Callable[..., Any]],
  4108. ) -> None:
  4109. super().__init__(name=None, layout=layout)
  4110. self.inputs = InputsKernel.unwrap_storage(inputs)
  4111. self.make_kernel_render = make_kernel_render
  4112. self.name = V.graph.register_buffer(self)
  4113. V.graph.register_operation(self)
  4114. def get_read_writes(self) -> dependencies.ReadWrites:
  4115. return self.extract_read_writes(normalize=True)
  4116. def extract_read_writes(self, normalize: bool = False) -> dependencies.ReadWrites:
  4117. name = self.get_name()
  4118. indexer = self.get_layout().make_indexer()
  4119. def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
  4120. assert len(rindex) == 0
  4121. return ops.store(name, indexer(index), "fake")
  4122. deps = dependencies.extract_read_writes(
  4123. dummy, self.get_size(), (), normalize=normalize
  4124. )
  4125. for inp in self.inputs:
  4126. assert isinstance(inp, (ReinterpretView, Buffer)), type(inp)
  4127. assert isinstance(inp.layout, Layout), type(inp.layout)
  4128. indexer = inp.layout.make_indexer()
  4129. def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
  4130. assert len(rindex) == 0
  4131. return ops.load(inp.get_name(), indexer(index))
  4132. deps.reads |= dependencies.extract_read_writes(
  4133. dummy, inp.get_size(), (), normalize=normalize
  4134. ).reads
  4135. return deps
  4136. def get_reduction_size(self) -> Sequence[Expr]:
  4137. return sympy.S.One
  4138. def get_reduction_type(self) -> Optional[str]:
  4139. return None
  4140. def should_allocate(self) -> bool:
  4141. return True
  4142. def simplify_and_reorder(
  4143. self,
  4144. extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
  4145. recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
  4146. ) -> tuple[tuple[Sequence[Expr], list[Expr]], Optional[LoopBody]]:
  4147. return (
  4148. (
  4149. self.get_size(),
  4150. [],
  4151. ),
  4152. None,
  4153. )
  4154. class TritonTemplateBuffer(TemplateBuffer):
  4155. def __init__(
  4156. self,
  4157. layout: Layout,
  4158. inputs: Sequence[IRNode],
  4159. make_kernel_render: Optional[Callable[_P, _T]],
  4160. mutated_inputs: Optional[Iterable[IRNode]] = None,
  4161. allowed_prologue_inps: Optional[OrderedSet[str]] = None,
  4162. ) -> None:
  4163. """
  4164. NOTE:[TritonTemplates with multiple outputs]
  4165. We want the ability for TritonTemplates to output multiple tensors. Triton
  4166. kernels have no notion of outputs and this is done by creating tensors that
  4167. are then mutated by the kernel. Currently our STORE_OUTPUT codegen doesn't
  4168. support creating multinode outputs for triton templates.
  4169. We work around this by creating an extra input buffer during the lowering
  4170. and we mark them as mutated inputs.
  4171. """
  4172. super().__init__(layout, inputs, make_kernel_render)
  4173. self.mutated_inputs = mutated_inputs
  4174. self.outputs: list[Buffer] = [self]
  4175. if mutated_inputs is not None:
  4176. # Ensure that the mutated inputs are only allowed for certain nodes
  4177. allowed_set = (
  4178. torch.ops.higher_order.flex_attention,
  4179. torch.ops.higher_order.flex_attention_backward,
  4180. )
  4181. current_node = V.graph.current_node.target
  4182. assert current_node in allowed_set, (
  4183. f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
  4184. )
  4185. assert isinstance(self.inputs[0], IRNode), type(self.inputs[0])
  4186. device = self.inputs[0].get_device()
  4187. self.outputs += [
  4188. MutationOutput(NoneLayout(device=device), buf, self)
  4189. for buf in mutated_inputs
  4190. ]
  4191. self.allowed_prologue_inps = (
  4192. allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
  4193. )
  4194. self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
  4195. self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
  4196. @cache_on_self_and_args("TritonTemplateBuffer")
  4197. def get_free_symbol_uses(
  4198. self, unbacked_only: bool = False
  4199. ) -> OrderedSet[sympy.Symbol]:
  4200. res = super().get_free_symbol_uses(unbacked_only)
  4201. subgraph_outs = self.subgraph_outs if self.subgraph_outs else []
  4202. subgraph_inps = self.subgraph_inps if self.subgraph_inps else []
  4203. for inp in subgraph_inps:
  4204. if isinstance(inp, sympy.Expr):
  4205. res.update(get_free_symbols(inp, unbacked_only))
  4206. elif isinstance(inp, IRNode):
  4207. res.update(inp.get_free_symbol_uses(unbacked_only))
  4208. else:
  4209. assert inp is None
  4210. for out in subgraph_outs:
  4211. if isinstance(out, IRNode):
  4212. res.update(out.get_free_symbol_uses(unbacked_only))
  4213. else:
  4214. assert out is None
  4215. return res
  4216. def get_outputs(self) -> list[Buffer]:
  4217. return self.outputs
  4218. def get_allowed_prologue_inps(self) -> OrderedSet[str]:
  4219. return self.allowed_prologue_inps
  4220. def __str__(self) -> str:
  4221. out = f"TritonTemplateBuffer(layout={self.layout})"
  4222. return out
# A primitive value, or a flat list of primitives, as used in the
# return annotation of ChoiceCaller.info_dict().
PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]]
  4224. class ChoiceCaller:
  4225. """
  4226. Represents a possible choice used in autotune_process.py.
  4227. During autotuning, self.benchmark() is first called to get benchmark result,
  4228. and if this choice is selected, self.output_node() is called to get the output_node.
  4229. Children classes: TritonTemplateCaller, CUDATemplateCaller.
  4230. """
  4231. def __init__(
  4232. self,
  4233. name: str,
  4234. input_nodes: list[Buffer],
  4235. layout: Layout,
  4236. description: str,
  4237. ) -> None:
  4238. super().__init__()
  4239. self.name = name
  4240. self.layout = layout
  4241. self.input_nodes = input_nodes
  4242. # An additional description used to describe the choice (useful for
  4243. # knowing what autotuning is choosing)
  4244. self.description = description
  4245. def benchmark(self, *args: Any, out: torch.Tensor) -> float:
  4246. algo = self.to_callable()
  4247. benchmark_configs = {
  4248. "warmup": autotune_warmup,
  4249. "rep": autotune_rep,
  4250. }
  4251. if config.profile_bandwidth_with_do_bench_using_profiling:
  4252. return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs)
  4253. return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs)
  4254. def call_name(self) -> str:
  4255. raise NotImplementedError
  4256. def to_callable(self) -> Callable[..., Any]:
  4257. raise NotImplementedError
  4258. def kernel_hash_key(self) -> str:
  4259. """
  4260. Hash key for the underlying kernel. By default, we assume there are no
  4261. runtime params, so kernel hash key defaults to choice caller's hash key.
  4262. """
  4263. return self.hash_key()
  4264. def hash_key(self) -> str:
  4265. raise NotImplementedError
  4266. def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]:
  4267. raise NotImplementedError
  4268. def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
  4269. """Information returned here is logged to the autotune log file when that is enabled."""
  4270. return {}
  4271. def autoheuristic_id(self) -> str:
  4272. return "unsupported_choice"
  4273. class TritonTemplateCallerBase(ChoiceCaller):
  4274. def get_make_kernel_render(self) -> Any:
  4275. raise NotImplementedError
  4276. class MultiTemplateBuffer(TritonTemplateBuffer):
  4277. """
  4278. Represents a Buffer with multiple backing implementation choices.
  4279. Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential
  4280. epilogue we will benchmark each of the choices with the epilogue to determine an implementation.
  4281. Otherwise, the fastest base choice will be chosen.
  4282. """
  4283. def __init__(
  4284. self,
  4285. layout: Layout,
  4286. inputs: Sequence[IRNode],
  4287. choice_timings_fn: Callable[[Optional[int]], dict[ChoiceCaller, float]],
  4288. unfiltered_choices: list[ChoiceCaller],
  4289. allowed_prologue_inps: OrderedSet[str],
  4290. ) -> None:
  4291. super().__init__(
  4292. layout=layout,
  4293. inputs=inputs,
  4294. make_kernel_render=None,
  4295. allowed_prologue_inps=allowed_prologue_inps,
  4296. )
  4297. self._choice_timings_fn = choice_timings_fn
  4298. self._choice_timings: dict[Optional[int], dict[ChoiceCaller, float]] = {}
  4299. self.original_inputs = inputs
  4300. self._output_plannable = all(
  4301. isinstance(choice, TritonTemplateCallerBase)
  4302. or (
  4303. isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller)
  4304. and choice.has_out_variant
  4305. )
  4306. for choice in unfiltered_choices
  4307. )
  4308. self._make_kernel_renders: dict[Optional[int], Any] = {}
  4309. @property
  4310. def output_plannable(self) -> bool:
  4311. """
  4312. Are all possible choices TritonTemplates or Extern Kernels with out variants
  4313. """
  4314. return self._output_plannable
  4315. def choice_timings(
  4316. self, hint_override: Optional[int] = None
  4317. ) -> dict[ChoiceCaller, float]:
  4318. if hint_override not in self._choice_timings:
  4319. self._choice_timings[hint_override] = self._choice_timings_fn(hint_override)
  4320. return self._choice_timings[hint_override]
  4321. @contextlib.contextmanager
  4322. def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[None]:
  4323. assert isinstance(
  4324. caller, torch._inductor.select_algorithm.TritonTemplateCaller
  4325. ), type(caller)
  4326. assert self.layout == caller.layout
  4327. render = self.make_kernel_render
  4328. self.make_kernel_render = caller.get_make_kernel_render()
  4329. try:
  4330. yield
  4331. finally:
  4332. self.make_kernel_render = render
  4333. def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase) -> None:
  4334. assert isinstance(
  4335. caller, torch._inductor.select_algorithm.TritonTemplateCaller
  4336. ), type(caller)
  4337. assert self.get_size() == caller.layout.size
  4338. assert self.get_stride() == caller.layout.stride
  4339. self.make_kernel_render = caller.get_make_kernel_render()
  4340. def get_min_choice(
  4341. self, hint_override: Optional[int] = None
  4342. ) -> tuple[ChoiceCaller, float]:
  4343. timings = self.choice_timings(hint_override=hint_override)
  4344. min_choice = min(timings, key=timings.get) # type: ignore[arg-type]
  4345. return (min_choice, timings[min_choice])
  4346. def finalize_as_triton_callers(
  4347. self, callers: dict[Optional[int], TritonTemplateCallerBase]
  4348. ) -> None:
  4349. """Finalize with multiple callers for different hint overrides"""
  4350. for hint_override, caller in callers.items():
  4351. self._make_kernel_renders[hint_override] = caller.get_make_kernel_render()
  4352. # Set the default to be the one without hint override
  4353. self.make_kernel_render = self._make_kernel_renders[None]
  4354. class CUDATemplateBuffer(TemplateBuffer):
  4355. def __init__(
  4356. self,
  4357. layout: Layout,
  4358. inputs: Sequence[IRNode],
  4359. make_kernel_render: Callable[_P, _T],
  4360. workspace_size: int,
  4361. template: CUDATemplate,
  4362. supports_epilogue_fusion: bool,
  4363. ) -> None:
  4364. super().__init__(layout, inputs, make_kernel_render)
  4365. # Global memory (in bytes) needed for this template.
  4366. self.workspace_size = workspace_size
  4367. self.template = template
  4368. self.supports_epilogue_fusion = supports_epilogue_fusion
  4369. def get_workspace_size(self) -> int:
  4370. return self.workspace_size if self.workspace_size is not None else 0
  4371. def emulate_store_fn(self) -> None:
  4372. for output in self.get_outputs():
  4373. ops.store(output.get_name(), None, None)
  4374. class CppTemplateBuffer(TemplateBuffer):
  4375. def __init__(
  4376. self,
  4377. layout: Layout,
  4378. inputs: Sequence[IRNode],
  4379. make_kernel_render: Callable[_P, _T],
  4380. template: CUDATemplate,
  4381. choice: Any,
  4382. ) -> None:
  4383. super().__init__(layout, inputs, make_kernel_render)
  4384. self.template = template
  4385. self.choice = choice
  4386. self.outputs: Optional[list[Buffer]] = None
  4387. def get_layout(self) -> Layout:
  4388. if isinstance(self.layout, MultiOutputLayout):
  4389. assert isinstance(self.outputs, Iterable), type(self.outputs)
  4390. first_output = self.outputs[0]
  4391. assert isinstance(first_output, Buffer), type(first_output)
  4392. layout = first_output.layout
  4393. assert isinstance(layout, Layout), type(layout)
  4394. return layout
  4395. else:
  4396. return super().get_layout()
  4397. class CuteDSLTemplateBuffer(TemplateBuffer):
  4398. """
  4399. Buffer for CuteDSL (CUTLASS Python DSL) template kernels.
  4400. Similar to other template buffers but specialized for CuteDSL operations.
  4401. """
  4402. def __init__(
  4403. self,
  4404. layout: Layout,
  4405. inputs: Sequence[IRNode],
  4406. make_kernel_render: Callable[_P, _T],
  4407. template: Any,
  4408. mutated_inputs: Optional[Iterable[IRNode]] = None,
  4409. ) -> None:
  4410. super().__init__(layout, inputs, make_kernel_render)
  4411. self.template = template
  4412. self.mutated_inputs = mutated_inputs
  4413. self.outputs: list[Buffer] = [self]
  4414. if mutated_inputs is not None:
  4415. assert isinstance(self.inputs[0], IRNode), type(self.inputs[0])
  4416. device = self.inputs[0].get_device()
  4417. self.outputs += [
  4418. MutationOutput(NoneLayout(device=device), buf, self)
  4419. for buf in mutated_inputs
  4420. ]
  4421. def get_outputs(self) -> list[Buffer]:
  4422. return self.outputs
  4423. def is_node_sequence(
  4424. nodes: Sequence[Union[IRNode, Sequence[IRNode]]],
  4425. ) -> TypeIs[Sequence[IRNode]]:
  4426. return all(isinstance(n, IRNode) for n in nodes)
  4427. @ir_dataclass(frozen=False)
  4428. class InputsKernel(OperationBuffer):
  4429. inputs: Sequence[Union[IRNode, Sequence[IRNode]]]
  4430. def input_name(self, i: int) -> str:
  4431. input = self.inputs[i]
  4432. assert isinstance(input, IRNode)
  4433. return input.get_name()
  4434. def get_read_writes(self) -> dependencies.ReadWrites:
  4435. reads = OrderedSet[dependencies.Dep]()
  4436. StarDep = dependencies.StarDep
  4437. for input in self.inputs:
  4438. if isinstance(input, Sequence):
  4439. reads.update(StarDep(x.get_name()) for x in input)
  4440. elif isinstance(input, ShapeAsConstantBuffer):
  4441. # Skip creating dependency for symbolics as they're visible globally
  4442. continue
  4443. else:
  4444. reads.add(StarDep(input.get_name()))
  4445. writes = OrderedSet[dependencies.Dep](
  4446. StarDep(buf.get_name()) for buf in self.get_outputs()
  4447. )
  4448. return dependencies.ReadWrites(
  4449. reads=reads,
  4450. writes=writes,
  4451. index_exprs=OrderedSet(),
  4452. )
  4453. def get_reads(self) -> OrderedSet[Dep]:
  4454. return self.get_read_writes().reads
  4455. @classmethod
  4456. def unwrap_storage_for_input(cls, x: IRNode) -> IRNode:
  4457. if isinstance(x, TensorBox):
  4458. x = x.data
  4459. if isinstance(x, StorageBox):
  4460. x = x.data
  4461. if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
  4462. x = ExternKernel.realize_input(x)
  4463. if isinstance(x, TensorBox):
  4464. # when converting to ReinterpretView fails in the
  4465. # realize_input call above, the result will be wrapped
  4466. # into TensorBox / StorageBox pair as a result of the
  4467. # cls.copy_input call; so we should unwrap recursively
  4468. return cls.unwrap_storage_for_input(x)
  4469. if isinstance(x, TorchBindObject):
  4470. return x
  4471. assert isinstance(x, (Buffer, ReinterpretView)), type(x)
  4472. return x
  4473. @staticmethod
  4474. def unwrap_storage(
  4475. inputs: Sequence[Union[IRNode, Sequence[IRNode]]],
  4476. ) -> list[Union[IRNode, Sequence[IRNode]]]:
  4477. inputs_new: list[Union[IRNode, Sequence[IRNode]]] = []
  4478. for x in inputs:
  4479. if isinstance(x, Sequence):
  4480. x = [InputsKernel.unwrap_storage_for_input(i) for i in x]
  4481. else:
  4482. x = InputsKernel.unwrap_storage_for_input(x)
  4483. inputs_new.append(x)
  4484. return inputs_new
  4485. def is_extern(self) -> bool:
  4486. return True
  4487. def num_reads(self) -> int:
  4488. return 1
  4489. @cache_on_self_and_args("InputsKernel")
  4490. def get_free_symbol_uses(
  4491. self, unbacked_only: bool = False
  4492. ) -> OrderedSet[sympy.Symbol]:
  4493. r = OrderedSet[sympy.Symbol]()
  4494. for inp in self.inputs:
  4495. if isinstance(inp, IRNode):
  4496. r |= inp.get_free_symbol_uses(unbacked_only)
  4497. else:
  4498. for inner_inp in inp:
  4499. r |= inner_inp.get_free_symbol_uses(unbacked_only)
  4500. return r
  4501. class NopKernel(InputsKernel):
  4502. def is_no_op(self) -> bool:
  4503. return True
  4504. def get_reads(self) -> OrderedSet[Dep]:
  4505. return OrderedSet()
  4506. class ConcatKernel(NopKernel):
  4507. """
  4508. There isn't actually a real kernel for concat, we just change the
  4509. storage for the upstream data.
  4510. """
  4511. @classmethod
  4512. def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox:
  4513. """
  4514. Create the concat kernel from inputs
  4515. """
  4516. device = inputs[0].get_device()
  4517. dtype = inputs[0].get_dtype()
  4518. new_size = list(inputs[0].get_size())
  4519. offsets_start = [0]
  4520. offsets_end = [new_size[dim]]
  4521. assert 0 <= dim < len(new_size)
  4522. for i in range(1, len(inputs)):
  4523. input_size = inputs[i].get_size()
  4524. offsets_start.append(new_size[dim])
  4525. assert len(input_size) == len(new_size)
  4526. assert inputs[i].get_dtype() == dtype
  4527. assert inputs[i].get_device() == device
  4528. for j in range(len(new_size)):
  4529. if j == dim:
  4530. new_size[j] = new_size[j] + input_size[j]
  4531. else:
  4532. new_size[j] = V.graph.sizevars.check_equals_and_simplify(
  4533. new_size[j], input_size[j]
  4534. )
  4535. offsets_end.append(new_size[dim])
  4536. output_stride: Sequence[int] = FlexibleLayout.contiguous_strides(new_size)
  4537. if config.comprehensive_padding:
  4538. # Ensure the output stride matches the alignment requirements
  4539. output_stride = Layout._pad_strides(
  4540. output_stride, new_size, inputs[0].dtype
  4541. )
  4542. # If any of the inputs is in CL format, use CL format for the output
  4543. for i in range(len(inputs)):
  4544. x = inputs[i]
  4545. if is_storage_and_layout(x):
  4546. layout = x.get_layout()
  4547. if isinstance(
  4548. layout, FixedLayout
  4549. ) and Layout.is_channels_last_contiguous(layout.size, layout.stride):
  4550. # use CL stride for the output
  4551. output_stride = make_channels_last_strides_for(new_size)
  4552. break
  4553. any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs)
  4554. fx_node_args = V.graph.current_node.args[0]
  4555. assert isinstance(fx_node_args, list), type(fx_node_args)
  4556. # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output
  4557. if any_input_is_storage_and_layout is False and any(
  4558. "val" in arg.meta
  4559. and (
  4560. arg.meta["val"].is_contiguous(memory_format=torch.channels_last)
  4561. or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d)
  4562. )
  4563. for arg in fx_node_args
  4564. ):
  4565. output_stride = make_channels_last_strides_for(new_size)
  4566. is_pinned = all(
  4567. is_storage_and_layout(x) and x.get_layout().is_pinned for x in inputs
  4568. )
  4569. assert device is not None
  4570. concat_kernel = ConcatKernel(
  4571. name=None,
  4572. layout=FixedLayout(
  4573. device=device,
  4574. dtype=dtype,
  4575. size=new_size,
  4576. stride=output_stride,
  4577. is_pinned=is_pinned,
  4578. ),
  4579. inputs=[],
  4580. )
  4581. kernel = StorageBox(concat_kernel)
  4582. op_names = []
  4583. for i, inp in enumerate(inputs):
  4584. assert isinstance(inp, (BaseView, MutableBox)), type(inp)
  4585. input_buffer = cls.realize_into(
  4586. inp,
  4587. SliceView.create(
  4588. kernel, dim, offsets_start[i], offsets_end[i], clamp=False
  4589. ),
  4590. )
  4591. assert isinstance(input_buffer, Buffer), type(input_buffer)
  4592. assert isinstance(concat_kernel.inputs, list), type(concat_kernel.inputs)
  4593. concat_kernel.inputs.append(input_buffer)
  4594. if isinstance(inp.data, BaseView):
  4595. input_unwrapped = inp.data.unwrap_view()
  4596. else:
  4597. input_unwrapped = inp.data
  4598. if (
  4599. isinstance(input_unwrapped, StorageBox)
  4600. and input_unwrapped.is_input_buffer()
  4601. and (dev := inp.get_device()) is not None
  4602. and is_gpu(dev.type)
  4603. and not is_dynamic(input_buffer)
  4604. ):
  4605. op_names.append(input_buffer.get_operation_name())
  4606. if len(op_names) > 1 and V.graph.has_feature(device, BackendFeature.FOREACH):
  4607. V.graph.register_operation_list(op_names)
  4608. concat_kernel.name = V.graph.register_buffer(concat_kernel)
  4609. concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs)
  4610. V.graph.register_operation(concat_kernel)
  4611. return kernel
  4612. @classmethod
  4613. def can_realize_into_without_copy(
  4614. cls, src: IRNode, dst: Optional[IRNode] = None
  4615. ) -> bool:
  4616. if isinstance(src, TensorBox):
  4617. # unwrap a TensorBox
  4618. return cls.can_realize_into_without_copy(src.data, dst)
  4619. assert isinstance(src, (BaseView, StorageBox)), type(src)
  4620. if isinstance(src.data, MultiTemplateBuffer):
  4621. if (
  4622. not isinstance(src.data.layout, FixedLayout)
  4623. or not src.data.output_plannable
  4624. ):
  4625. return False
  4626. # we call can_realize_into_without_copy in cat lowering before we've decided
  4627. # on output format, optimistically assume layout matches
  4628. if dst is None:
  4629. return True
  4630. # otherwise, check equality of layouts
  4631. if not len(src.get_stride()) == len(dst.get_stride()):
  4632. return False
  4633. return all(
  4634. V.graph.sizevars.statically_known_equals(s1, s2)
  4635. for s1, s2 in zip(src.get_stride(), dst.get_stride())
  4636. )
  4637. return (
  4638. hasattr(src.data, "layout")
  4639. and isinstance(src.data.layout, FlexibleLayout)
  4640. and not isinstance(src.data, ExternKernelAlloc)
  4641. )
  4642. @cache_on_self_and_args("ConcatKernel")
  4643. def get_free_symbol_uses(
  4644. self, unbacked_only: bool = False
  4645. ) -> OrderedSet[sympy.Symbol]:
  4646. return NopKernel.get_free_symbol_uses(self, unbacked_only)
  4647. @classmethod
  4648. def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode:
  4649. # Attempt to turn this into a ReinterpretView rather than assert.
  4650. # This has concessions around layout, as as_storage_and_layout
  4651. # can cause us to go from flexible to fixed layout.
  4652. if not isinstance(dst, ReinterpretView):
  4653. if is_storage_and_layout(dst):
  4654. storage, layout = as_storage_and_layout(dst)
  4655. dst = ReinterpretView(data=storage, layout=layout)
  4656. assert isinstance(dst, ReinterpretView), type(dst)
  4657. if isinstance(src, TensorBox):
  4658. # unwrap a TensorBox
  4659. return cls.realize_into(src.data, dst)
  4660. if isinstance(src, StorageBox):
  4661. src.realize()
  4662. # ExternKernelAlloc has specific requirements for output layout, should create a copy
  4663. assert hasattr(src.data, "layout")
  4664. if cls.can_realize_into_without_copy(src, dst):
  4665. src.data.layout = NonOwningLayout(dst)
  4666. return src.data
  4667. # introduce a copy
  4668. pw = Pointwise.create(
  4669. device=src.get_device(),
  4670. dtype=src.get_dtype(),
  4671. inner_fn=src.make_loader(),
  4672. ranges=[
  4673. V.graph.sizevars.check_equals_and_simplify(a, b)
  4674. for a, b in zip(src.get_size(), dst.get_size())
  4675. ],
  4676. )
  4677. return cls.realize_into(pw, dst)
  4678. def should_allocate(self) -> bool:
  4679. return True
  4680. @ir_dataclass(frozen=False)
  4681. class ExternKernel(InputsKernel):
  4682. """
  4683. A class that represents Kernels which are not directly lowered to Inductor
  4684. Loop Level IR, such as custom operators, or aten operators which we fallback to.
  4685. """
  4686. constant_args: Sequence[Any] = ()
  4687. kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
  4688. output_view: Optional[ReinterpretView] = None
  4689. python_kernel_name: Optional[str] = None
  4690. cpp_kernel_name: Optional[str] = None
  4691. # FIXME: in some cases we still need to explicitly pass in ordered_kwargs_for_cpp_kernel.
  4692. # We shouldn't need to do this, since the information can be retrieved from op_overload._schema.
  4693. ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
  4694. default_factory=list
  4695. )
  4696. op_overload: Optional[_OpOverloads] = None
  4697. arg_properties: Optional[list[dict[str, Any]]] = None
  4698. allarg_properties: dict[str, dict[str, Any]] = dataclasses.field(
  4699. default_factory=dict
  4700. )
  4701. kwarg_properties: Optional[dict[str, dict[str, Any]]] = None
  4702. unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
  4703. default_factory=dict
  4704. )
  4705. mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list)
    def __init__(
        self,
        name: Optional[str],
        layout: OutputSpec,
        inputs: Sequence[Union[IRNode, Sequence[IRNode]]],
        constant_args: Sequence[Any] = (),
        kwargs: Optional[dict[str, Any]] = None,
        output_view: Optional[ReinterpretView] = None,
        python_kernel_name: Optional[str] = None,
        cpp_kernel_name: Optional[str] = None,
        ordered_kwargs_for_cpp_kernel: Iterable[str] = (),
        op_overload: Optional[_OpOverloads] = None,
    ) -> None:
        """
        Initialize the kernel's arguments, kernel names and schema-derived
        argument properties, and remember the FX node being lowered.
        """
        super().__init__(
            name=name,
            layout=layout,
            inputs=inputs,
        )
        self.constant_args = constant_args
        # A missing (or empty) kwargs argument becomes a fresh dict.
        self.kwargs = kwargs if kwargs else {}
        self.output_view = output_view
        # op_overload is assigned before the kernel-name setters are called.
        self.op_overload = op_overload
        self.set_cpp_kernel_name(cpp_kernel_name)
        self.set_python_kernel_name(python_kernel_name)
        # Must be set before collect_arg_kwarg_properties(), which reads it
        # and may fill it in from the op schema when it is empty.
        self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
        self.collect_arg_kwarg_properties()
        self.unbacked_bindings = {}
        self.mutation_outputs = []
        self.fx_node = V.graph.current_node
  4735. def get_outputs(self) -> list[Buffer]:
  4736. return [self, *self.mutation_outputs]
def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
    """Return the unbacked symbols defined by this node: always an empty set here."""
    return OrderedSet()
  4739. def collect_arg_kwarg_properties(self) -> None:
  4740. # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional
  4741. # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen
  4742. self.arg_properties = (
  4743. [
  4744. {
  4745. "name": x.name,
  4746. "type": x.real_type,
  4747. "default_value": x.default_value,
  4748. }
  4749. for x in self.op_overload._schema.arguments
  4750. if not x.kwarg_only
  4751. ]
  4752. if isinstance(self.op_overload, torch._ops.OpOverload)
  4753. else [{} for i in range(len(self.inputs))]
  4754. )
  4755. self.allarg_properties = (
  4756. {
  4757. x.name: {"type": x.real_type, "default_value": x.default_value}
  4758. for x in self.op_overload._schema.arguments
  4759. }
  4760. if isinstance(self.op_overload, torch._ops.OpOverload)
  4761. else {}
  4762. )
  4763. # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes
  4764. # ordered_kwargs_for_cpp_kernel is explicitly passed in.
  4765. if isinstance(self.op_overload, torch._ops.OpOverload):
  4766. if not self.ordered_kwargs_for_cpp_kernel:
  4767. self.ordered_kwargs_for_cpp_kernel = [
  4768. x.name for x in self.op_overload._schema.arguments if x.kwarg_only
  4769. ]
  4770. self.schema_kwargs = [
  4771. x for x in self.op_overload._schema.arguments if x.kwarg_only
  4772. ]
  4773. else:
  4774. self.schema_kwargs = []
  4775. def decide_layout(self) -> None:
  4776. if isinstance(self.layout, FlexibleLayout):
  4777. self.apply_constraint()
  4778. self.freeze_layout()
  4779. def codegen_comment(self, wrapper: PythonWrapperCodegen) -> None:
  4780. origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper)
  4781. if origin_str:
  4782. wrapper.make_comment(origin_str)
def codegen(self, wrapper: PythonWrapperCodegen) -> None:
    """Emit wrapper code for this kernel.

    Not implemented on this class; it always raises NotImplementedError.
    """
    raise NotImplementedError
  4785. def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
  4786. self.cpp_kernel_name = cpp_kernel_name
  4787. if not V.graph.cpp_wrapper or not isinstance(
  4788. self.op_overload, torch._ops.OpOverload
  4789. ):
  4790. return
  4791. kernel = self.op_overload
  4792. if self.cpp_kernel_name is None:
  4793. # Try to construct cpp_kernel_name from op_overload
  4794. if kernel.namespace == "aten":
  4795. # Calling with the default kernel name can lead to ambiguous behavior like the following example.
  4796. # repeat_interleave(const at::Tensor & repeats, std::optional<int64_t> output_size=std::nullopt)
  4797. # repeat_interleave(const at::Tensor & self, int64_t repeats,
  4798. # std::optional<int64_t> dim=std::nullopt, std::optional<int64_t> output_size=std::nullopt)
  4799. opname = (
  4800. kernel.__name__.split(".")[0]
  4801. if kernel._overloadname == "default"
  4802. else kernel.__name__.replace(".", "_")
  4803. )
  4804. self.cpp_kernel_name = f"at::_ops::{opname}::call"
  4805. else:
  4806. self.cpp_kernel_name = kernel._schema.name
  4807. def set_python_kernel_name(self, python_kernel_name: Optional[str]) -> None:
  4808. self.python_kernel_name = python_kernel_name
  4809. if python_kernel_name is not None:
  4810. return
  4811. kernel = self.op_overload
  4812. if kernel is None:
  4813. pass
  4814. elif isinstance(kernel, torch._ops.HigherOrderOperator):
  4815. self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}"
  4816. else:
  4817. self.python_kernel_name = (
  4818. f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
  4819. )
  4820. def get_kernel_name(self) -> str:
  4821. from .codegen.cpp_wrapper_cpu import CppWrapperCpu
  4822. device = d.type if (d := self.get_device()) else V.graph.device_type
  4823. if V.graph.fx_wrapper:
  4824. assert self.python_kernel_name is not None
  4825. return self.python_kernel_name
  4826. elif V.graph.cpp_wrapper:
  4827. assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type(
  4828. V.graph.wrapper_code
  4829. )
  4830. assert self.cpp_kernel_name is not None
  4831. return V.graph.wrapper_code.get_c_shim_func_name(
  4832. self.cpp_kernel_name, device
  4833. )
  4834. else:
  4835. assert self.python_kernel_name is not None
  4836. return self.python_kernel_name
  4837. @staticmethod
  4838. def copy_input(x: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]:
  4839. pw = Pointwise.create(
  4840. device=x.get_device(),
  4841. dtype=x.get_dtype(),
  4842. inner_fn=x.make_loader(),
  4843. ranges=x.get_size(),
  4844. origin_node=x.get_origin_node(),
  4845. traceback=x.get_traceback(),
  4846. )
  4847. pw.realize()
  4848. return pw
@classmethod
def process_kernel(
    cls, kernel: _OpOverloads, *args: Any, **kwargs: Any
) -> tuple[
    Any,
    list[Any],
    list[Any],
    Callable[[Any, Any], Any],
    Optional[dict[sympy.Symbol, pytree.KeyPath]],
]:
    """Run ``kernel`` on example inputs and split its arguments for codegen.

    The positional and keyword arguments are flattened with pytree and
    partitioned into "tensor" arguments (IRNodes other than GeneratorState)
    and everything else. Tensor arguments are realized and have their
    layouts frozen; ``kernel`` is then called on example values built from
    them to obtain an example output.

    Returns a 5-tuple:
        example_output: whatever ``kernel`` returned for the example inputs.
        tensor_args: the realized tensor arguments, in flattened order.
        non_tensor_args: the remaining flattened arguments (sympy ``Expr``
            values are replaced by symint nodes).
        unflatten_args: callable rebuilding ``(args, kwargs)`` from a new
            tensor-arg list and a new non-tensor-arg list.
        unbacked_bindings: result of ``compute_unbacked_bindings`` when the
            fake mode has a shape env, otherwise ``None``.

    Side effect: sets ``V.graph.disable_cudagraphs_reason`` when any output
    is a sparse tensor.
    """
    binded_args = {"args": args, "kwargs": kwargs}

    args_flat, args_spec = pytree.tree_flatten(binded_args)

    # Parallel to args_flat: True where the argument went into tensor_args.
    is_arg_tensor = []
    # tensor_args can be either tensor or torchbind objects
    tensor_args = []
    non_tensor_args: list[Any] = []
    for arg in args_flat:
        is_arg_tensor.append(
            isinstance(arg, IRNode) and not isinstance(arg, GeneratorState)
        )
        if is_arg_tensor[-1]:
            tensor_args.append(arg)
        else:
            if isinstance(arg, Expr):
                arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
            non_tensor_args.append(arg)

    def unflatten_args(
        new_tensor_args: Sequence[_T], new_non_tensor_args: Sequence[_T]
    ) -> tuple[list[_T], dict[str, _T]]:
        # Interleave the two lists back into the original flattened order
        # recorded in is_arg_tensor, then restore the args/kwargs structure.
        result = []
        it_tensors = iter(new_tensor_args)
        it_non_tensors = iter(new_non_tensor_args)
        for is_tensor in is_arg_tensor:
            if is_tensor:
                result.append(next(it_tensors))
            else:
                result.append(next(it_non_tensors))
        r = pytree.tree_unflatten(result, args_spec)
        return r.get("args", []), r.get("kwargs", {})

    tensor_args = [cls.realize_input(x) for x in tensor_args]

    # freeze layout otherwise our output stride calculation might
    # become incorrect
    for x in tensor_args:
        if is_storage_and_layout(x):
            as_storage_and_layout(x, freeze=True)

    # Rerun fake tensor propagation, because Inductor may have changed the
    # strides of inputs and we need to determine accurately what the
    # output stride will be.
    example_args: list[
        Union[
            torch.Tensor, torch._C.ScriptObject, FakeScriptObject, torch.Generator
        ]
    ] = []

    # We need to retain the constant values of fake tensors that we originally
    # propagated the graph with, because for some operators running without a
    # constant would trigger an error / DataDependentException
    for x in tensor_args:
        # if x is a view of a constant, we need to realize the view
        # (we can't pass the constant into the kernel directly)
        if not isinstance(x, BaseView) and x.get_name() in V.graph.constants:
            example_args.append(V.graph.constants[x.get_name()])
        elif (
            not isinstance(x, BaseView)
            and x.get_name() in V.graph.torchbind_constants
        ):
            example_args.append(V.graph.torchbind_constants[x.get_name()])
        elif isinstance(x, TorchBindObject):
            example_args.append(x.get_value())
        # NOTE(review): GeneratorState instances are excluded from tensor_args
        # by the partition above, so this branch looks reachable only if
        # realize_input() can return one -- confirm.
        elif isinstance(x, torch._inductor.ir.GeneratorState):
            device_index = x.device.index
            assert x.device.type == "cuda" and device_index is not None
            example_args.append(
                torch.cuda.default_generators[device_index].clone_state()
            )
        else:
            example_args.append(ir_node_to_tensor(x, guard_shape=True))

    new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
    example_output = kernel(*new_args, **new_kwargs)

    unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
    if shape_env := V.fake_mode.shape_env:
        node_meta_val = V.current_node.meta.get("val")
        ctx: AbstractContextManager[None] = nullcontext()
        if V.current_node.target == torch._higher_order_ops.effects.with_effects:
            # remove the first effect token in meta["val"] and meta["unbacked_bindings"]
            node_meta_val = node_meta_val[1]
            ctx = _remove_effect_token_unbacked_bindings(V.current_node)

        with ctx:
            rebind_unbacked(shape_env, V.current_node, example_output)
        unbacked_bindings = compute_unbacked_bindings(
            shape_env, example_output, node_meta_val
        )

    # Normalize to a sequence so single and multiple outputs are checked alike.
    example_out_li = (
        [example_output]
        if not isinstance(example_output, (list, tuple))
        else example_output
    )
    for t in example_out_li:
        if isinstance(t, torch.Tensor) and t.is_sparse:
            msg = "sparsity not handled. Please file issue for sparse inference weights."
            if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
                msg = f"{msg} Found from : \n {stack_trace}"
            V.graph.disable_cudagraphs_reason = msg

    return (
        example_output,
        tensor_args,
        non_tensor_args,
        unflatten_args,
        unbacked_bindings,
    )
@classmethod
def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView:
    """
    In order to pass this to an extern kernel we need a
    ReinterpretView not a View.  This allows us to avoid some
    unneeded copies.

    ``x`` must be a BaseView; a ReinterpretView is returned unchanged.
    As a side effect the layout of the unwrapped storage is frozen.

    Raises:
        NotImplementedError: if the view's index expression cannot be
            written as ``sum(range_var * stride) + offset``.
    """
    assert isinstance(x, BaseView), type(x)
    if isinstance(x, ReinterpretView):
        return x

    # NOTE: Don't use extract_read_writes here as it fails when
    # make_loader() inlines the computation
    x_unwrap_view = x.unwrap_view()
    buf = V.graph.get_buffer(x_unwrap_view.get_name())
    assert buf is not None
    x_unwrap_view_fx_node = buf.get_origin_node()
    # Prefer channels last format according to how the format is set from eager.
    if (
        x_unwrap_view_fx_node is not None
        and "val" in x_unwrap_view_fx_node.meta
        and isinstance(x_unwrap_view, (ReinterpretView, Buffer, MutableBox))
        and isinstance(x_unwrap_view.layout, FlexibleLayout)
        and (
            x_unwrap_view_fx_node.meta["val"].is_contiguous(
                memory_format=torch.channels_last
            )
            or x_unwrap_view_fx_node.meta["val"].is_contiguous(
                memory_format=torch.channels_last_3d
            )
        )
    ):
        x_unwrap_view.freeze_layout_with_same_order(
            make_channels_last_strides_for(x_unwrap_view.get_size())
        )
    else:
        x_unwrap_view.freeze_layout()

    index_args, var_ranges = dependencies.index_vars_squeeze(
        x.get_size(), prefix="r"
    )
    range_vars = index_args[0]
    index = x.make_indexer()(range_vars)

    index = V.graph.sizevars.simplify_with_ranges(index, var_ranges)
    strides = V.graph.sizevars.stride_vars(index, range_vars)
    offset = V.graph.sizevars.offset_var(index, range_vars)
    # Rebuild the index from the extracted strides/offset; a mismatch means
    # the view is not a pure strided reinterpretation of the storage.
    expected = sympy_dot(range_vars, strides) + offset

    if index != expected:
        log.debug(
            "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
            strides,
            offset,
            index,
        )
        raise NotImplementedError

    return ReinterpretView(
        data=x.data,
        layout=FixedLayout(
            device=x.get_device_or_error(),
            dtype=x.get_dtype(),
            size=x.get_size(),
            stride=strides,
            offset=offset,
            is_pinned=False,
        ),
    )
@classmethod
def realize_input(cls, x: IRNode) -> IRNode:
    """Turn ``x`` into a node that can be passed to an extern kernel.

    The checks run in order and the first match wins:
    ``None`` becomes a NoneAsConstantBuffer, symbolic/int values become a
    ShapeAsConstantBuffer, a Constant is materialized as a graph tensor
    constant, boxes and ReinterpretViews are unwrapped recursively, other
    views are realized and converted to a ReinterpretView when possible,
    and anything left over is copied with ``copy_input``.
    """
    if x is None:
        return NoneAsConstantBuffer()
    if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
        return ShapeAsConstantBuffer(expr=x)
    if isinstance(x, Constant):
        return V.graph.add_tensor_constant(
            torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
        )
    if isinstance(x, ConstantBuffer):
        return x
    if isinstance(x, TensorBox):
        return cls.realize_input(x.data)
    if isinstance(x, ReinterpretView):
        return ReinterpretView(
            data=cls.realize_input(x.data), layout=x.get_layout()
        )
    if isinstance(x, BaseView):
        x.realize()
        if is_storage_and_layout(x.unwrap_view()):
            try:
                return cls.convert_to_reinterpret_view(x)
            except NotImplementedError:
                # Not expressible as a strided view: fall through to the
                # remaining checks and, ultimately, to copy_input().
                pass
    if isinstance(x, StorageBox):
        # TODO(jansel): impose layout preference on realized buffer
        x.realize()
        return x
    if isinstance(x, (NonTensorObj, ShapeAsConstantBuffer)):
        return x
    return cls.copy_input(x)
  5054. @classmethod
  5055. def require_stride1(cls, x: IRNode) -> IRNode:
  5056. if is_storage_and_layout(x):
  5057. if len(x.get_stride()) == 0:
  5058. return x
  5059. for stride in x.get_stride():
  5060. if stride == 1:
  5061. return x
  5062. return cls.copy_input(x)
@classmethod
def require_strides(
    cls,
    x: IRNode,
    order: Optional[Sequence[int]] = None,
    exact_strides: Optional[Sequence[_IntLike]] = None,
    allow_padding: bool = False,
) -> IRNode:
    """Return a node equivalent to ``x`` whose strides satisfy the request.

    Exactly how the request is expressed is up to the caller: either a
    stride ``order`` or ``exact_strides`` must be given (asserted below).
    Where possible the existing node is reused -- by freezing a flexible
    layout or by recognising that a fixed layout already matches --
    otherwise ``x`` is copied into a buffer with the requested layout.
    """
    assert order is not None or exact_strides is not None
    # Layout generally doesn't matter, but some consuming external ops might have requirements
    if x.get_numel() in (0, 1) and not exact_strides:
        return x

    # require x to have the layout
    if is_storage_and_layout(x):
        if isinstance(x.get_layout(), FlexibleLayout):
            if order:
                # If the FlexibleLayout already has the size and stride in the required order,
                # freeze it to a FixedLayout by using its current size and stride.
                # The behavior of using its current size and stride or the given order can be different
                # if the size and stride have ambiguity, for example for a 4D input where the iC = 1:
                # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last),
                # the current size and stride already satisfies this order.
                # However by freezing it to the required order, the layout will be changed to:
                # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary.
                use_current_stride_order = is_stride_order_storage_and_layout(
                    x, order
                ) and not free_unbacked_symbols(x.get_layout().stride)
                # fix flexiblelayout to be FixedLayout with stride_order
                as_storage_and_layout(
                    x,
                    freeze=True,
                    want_contiguous=False,
                    stride_order=(
                        get_stride_order(
                            V.graph.sizevars.size_hints_or_throw(
                                x.get_layout().stride
                            )
                        )
                        if use_current_stride_order
                        else order
                    ),
                    allow_padding=allow_padding,
                )
                return x
            else:
                # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides.
                as_storage_and_layout(
                    x,
                    freeze=True,
                    want_contiguous=False,
                    stride_order=None,
                    allow_padding=allow_padding,
                    exact_strides=exact_strides,
                )
                return x
        elif isinstance(x.get_layout(), (FixedLayout, NonOwningLayout)) and (
            (order and x.get_layout().is_stride_ordered(order))
            or (
                exact_strides
                and significant_strides_equal(
                    exact_strides, x.get_layout().stride, x.get_size()
                )
            )
        ):
            # Already-fixed layout that satisfies the request: reuse x.
            return (
                try_match_insignificant_strides(x, exact_strides)
                if exact_strides is not None
                else x
            )
        elif isinstance(
            (mutation_layout := x.get_layout()), MutationLayoutSHOULDREMOVE
        ):
            if isinstance(
                (real_layout := mutation_layout.real_layout()), FlexibleLayout
            ):
                raise AssertionError(
                    "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout"
                )
            elif isinstance(real_layout, FixedLayout) and (
                (order and real_layout.is_stride_ordered(order))
                or (
                    exact_strides
                    and significant_strides_equal(
                        exact_strides, real_layout.stride, x.get_size()
                    )
                )
            ):
                return x

    # TODO - Storage to InputBuffer
    if isinstance(x, InputBuffer) and (
        (order and x.get_layout().is_stride_ordered(order))
        or (
            exact_strides
            and significant_strides_equal(
                exact_strides, x.get_layout().stride, x.get_size()
            )
        )
    ):
        return x
    if (
        isinstance(x, TensorBox)
        and isinstance(x.data, BaseView)
        and not isinstance(x.data, ReinterpretView)
        and is_storage_and_layout(unwrap_view := x.unwrap_view())
        and hasattr(unwrap_view, "data")
        and not isinstance(unwrap_view.data, ExternKernelAlloc)
    ):
        try:
            # Replace the view by a ReinterpretView in place, then retry the
            # request against the converted node.
            x.data = cls.convert_to_reinterpret_view(x.data)
            if order:
                return cls.require_stride_order(
                    x, order, allow_padding=allow_padding
                )
            elif exact_strides:
                return cls.require_exact_strides(
                    x, exact_strides, allow_padding=allow_padding
                )
        except NotImplementedError:
            # The view cannot be reinterpreted; fall back to the copy below.
            pass

    # Preserve ExpandView representation that would be lost during copy_input
    # Without representation of the expand in inductor IR, in codegen we end up
    # launching a grid for the full size tensor and doing redundant computation
    # across expanded dims.
    # TODO: could also be good to have a codegen fix to recognize overlapping elements

    expanded_dims: Optional[list[int]] = None
    orig_size = x.get_size()
    if exact_strides is not None:
        sizevars = V.graph.sizevars
        # Dimensions requested with stride 0 whose size is statically >= 2.
        expanded_dims = [
            i
            for i in range(len(x.get_size()))
            if sizevars.statically_known_equals(exact_strides[i], 0)
            and sizevars.statically_known_geq(x.get_size()[i], 2)
        ]

        # Shrink each such dimension to size 1 before copying; it is
        # expanded back to orig_size after the copy.
        for dim in expanded_dims:
            x = torch._inductor.lowering.slice_(x, dim, 0, 1)

    # Although this is a clone, inductor is good about fusing clones into previous
    # operations if they weren't realized and their layouts were flexible.
    x = cls.copy_input(x)

    as_storage_and_layout(
        x,
        freeze=True,
        want_contiguous=False,
        stride_order=order,
        allow_padding=allow_padding,
        exact_strides=exact_strides,
    )
    if order:
        assert is_stride_order_storage_and_layout(x, order)
    elif expanded_dims:
        assert orig_size is not None and exact_strides is not None
        x = torch._inductor.lowering.expand(x, orig_size)
        # the expand may sometimes change insignificant strides, so match them back
        return try_match_insignificant_strides(x, exact_strides)

    return x
@classmethod
def require_exact_strides(
    cls, x: IRNode, exact_strides: Sequence[_IntLike], allow_padding: bool = False
) -> IRNode:
    """Shorthand for ``require_strides`` with ``exact_strides`` and no ``order``."""
    return cls.require_strides(
        x, exact_strides=exact_strides, allow_padding=allow_padding
    )
@classmethod
def require_stride_order(
    cls, x: IRNode, order: Sequence[int], allow_padding: bool = False
) -> IRNode:
    """Shorthand for ``require_strides`` with a stride ``order`` and no exact strides."""
    return cls.require_strides(x, order=order, allow_padding=allow_padding)
@classmethod
def require_channels_last(cls, x: IRNode) -> IRNode:
    """Require ``x`` to follow the stride order ``NHWC_STRIDE_ORDER``."""
    return cls.require_stride_order(x, NHWC_STRIDE_ORDER)
@classmethod
def require_channels_last_3d(cls, x: IRNode) -> IRNode:
    """Require ``x`` to follow the stride order ``NHWDC_STRIDE_ORDER``."""
    return cls.require_stride_order(x, NHWDC_STRIDE_ORDER)
  5236. @classmethod
  5237. def require_contiguous(cls, x: IRNode) -> IRNode:
  5238. def is_mkldnn_tensor(x: IRNode) -> bool:
  5239. try:
  5240. name = x.get_name()
  5241. except (AttributeError, NotImplementedError):
  5242. return False
  5243. return name in V.graph.constants and V.graph.constants[name].is_mkldnn
  5244. # TODO move this to the more proper places
  5245. if is_mkldnn_tensor(x):
  5246. return x
  5247. else:
  5248. return cls.require_exact_strides(
  5249. x, FlexibleLayout.contiguous_strides(x.get_size())
  5250. )
@classmethod
def require_contiguous_strides(cls, x: IRNode) -> IRNode:
    """Require ``x`` to have exactly the contiguous strides for its size.

    Unlike ``require_contiguous`` this applies no special case for mkldnn
    constants.
    """
    # TODO: combine this with require_contiguous after
    # https://github.com/pytorch/pytorch/pull/148235 lands.
    return cls.require_exact_strides(
        x, FlexibleLayout.contiguous_strides(x.get_size())
    )
def apply_constraint(self) -> None:
    """Hook called by decide_layout() before a flexible layout is frozen; no-op here."""
    pass
  5260. def fill_non_provided_args(
  5261. self, args: Sequence[Any], kwargs: dict[str, Any]
  5262. ) -> Sequence[Any]:
  5263. # Previously, we want to maintain forward-compatibility by skipping
  5264. # default args in the serialized artifacts in fbcode. However,
  5265. # some of our shim interfaces require default values being OrderedSet.
  5266. # Discussed with Sherlock offline and we decided to allow serializing
  5267. # default args into the C++ wrapper code for now. We will refine this
  5268. # part if we see real FC requirement. More details related to FC
  5269. # can be found at:
  5270. # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
  5271. assert isinstance(args, Sequence), type(args)
  5272. if not isinstance(args, list):
  5273. args = list(args)
  5274. assert self.arg_properties, "ExternKernel.arg_properties should not be empty"
  5275. n_args = len(args)
  5276. n_pos_args = len(self.arg_properties)
  5277. # For cpp wrapper, if some positional args are not provided, we need to check
  5278. # if they're in the kwargs or use their default value
  5279. if n_args < n_pos_args:
  5280. log.debug(
  5281. "%s has %d unprovided positional arguments. "
  5282. "Will check if they are in the keyword arguments or will use default values.",
  5283. self.op_overload,
  5284. n_pos_args - n_args,
  5285. )
  5286. for i in range(n_args, n_pos_args):
  5287. arg_name = self.arg_properties[i]["name"]
  5288. args.append(
  5289. kwargs[arg_name]
  5290. if arg_name in kwargs
  5291. else self.arg_properties[i]["default_value"]
  5292. )
  5293. return args
  5294. def codegen_const_args(self, names: Optional[list[str]] = None) -> list[str]:
  5295. if V.graph.cpp_wrapper:
  5296. result = []
  5297. # Aten ops follow the convention that tensor args are before non-tensor args,
  5298. # in which case the following 'len(self.inputs) + i' logic works. But this
  5299. # may not be true for other ops, and if that is the case, caller needs to
  5300. # pass in a list of const arg names for arg_properties lookup.
  5301. name_to_arg_properties = None
  5302. if names and self.arg_properties:
  5303. assert len(self.constant_args) == len(names), (
  5304. "names passed to codegen_const_args does not match self.constant_args"
  5305. )
  5306. name_to_arg_properties = {
  5307. arg.get("name"): arg for arg in self.arg_properties
  5308. }
  5309. for i, x in enumerate(self.constant_args):
  5310. if name_to_arg_properties is not None:
  5311. assert names is not None
  5312. prop = name_to_arg_properties.get(names[i])
  5313. type_ = prop.get("type") if prop else None
  5314. else:
  5315. idx = len(self.inputs) + i
  5316. type_ = (
  5317. self.arg_properties[idx].get("type")
  5318. if self.arg_properties and idx < len(self.arg_properties)
  5319. else None
  5320. )
  5321. result.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
  5322. return result
  5323. else:
  5324. return [V.graph.wrapper_code.val_to_arg_str(a) for a in self.constant_args]
  5325. def codegen_args(self) -> list[str]:
  5326. if V.graph.cpp_wrapper and self.op_overload is not None:
  5327. # cpp wrapper needs special logic to fill in missing args with default values
  5328. inputs = self.fill_non_provided_args(
  5329. [*self.inputs, *self.constant_args], self.kwargs
  5330. )
  5331. # fill_non_provided_args has handled constant args, so no need to codegen for that later
  5332. need_codegen_constant_args = False
  5333. else:
  5334. inputs = self.inputs
  5335. need_codegen_constant_args = True
  5336. args = []
  5337. for i, x in enumerate(inputs):
  5338. if V.graph.cpp_wrapper:
  5339. assert self.arg_properties and i < len(self.arg_properties), (
  5340. "Invalid access to ExternKernel.arg_properties"
  5341. )
  5342. type_ = self.arg_properties[i].get("type")
  5343. args.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
  5344. else:
  5345. args.append(V.graph.wrapper_code.val_to_arg_str(x))
  5346. if need_codegen_constant_args:
  5347. args.extend(self.codegen_const_args())
  5348. return args
  5349. def get_kwargs_value(self, arg_name: str, **kwargs: Any) -> Any:
  5350. """Given an argument name, queries for values in (in order):
  5351. 1. any provided kwargs for this function.
  5352. 2. the class self.kwargs member.
  5353. 3. any available default arguments in self.allarg_properties."""
  5354. if arg_name in kwargs:
  5355. return kwargs.get(arg_name)
  5356. if arg_name in self.kwargs:
  5357. return self.kwargs.get(arg_name)
  5358. if (arg := self.allarg_properties.get(arg_name)) is not None:
  5359. return arg.get("default_value")
  5360. raise AssertionError(f"{arg_name} not in self.allarg_properties")
  5361. def codegen_kwargs(self, skip_out: bool = False) -> list[str]:
  5362. if V.graph.cpp_wrapper:
  5363. if self.op_overload is not None and len(self.schema_kwargs) == 0:
  5364. # All the args should have been generated by fill_non_provided_args in codegen_args
  5365. return []
  5366. kwargs = []
  5367. for arg_name in self.ordered_kwargs_for_cpp_kernel:
  5368. if skip_out and arg_name == "out":
  5369. # ExternKernelOut has its own logic for inserting the out parameter
  5370. continue
  5371. v = self.get_kwargs_value(arg_name)
  5372. if isinstance(v, Expr):
  5373. kwargs.append(v)
  5374. else:
  5375. assert self.allarg_properties is not None
  5376. type_ = self.allarg_properties.get(arg_name, {}).get("type")
  5377. kwargs.append(V.graph.wrapper_code.val_to_arg_str(v, type_))
  5378. else:
  5379. kwargs = [
  5380. f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
  5381. for k, v in self.kwargs.items()
  5382. ]
  5383. return kwargs
  5384. def get_op_name(self) -> str:
  5385. if self.fx_node is not None:
  5386. target = self.fx_node.target
  5387. op_namespace = getattr(target, "__module__", "unknown_namespace")
  5388. op_namespace = op_namespace.replace("._ops.", ".ops.")
  5389. op_namespace = op_namespace.rsplit(".", 1)[0]
  5390. op_name = f"{op_namespace}.{target}"
  5391. else:
  5392. op_name = "unknown_op"
  5393. return op_name
  5394. def codegen_size_asserts(self, wrapper: PythonWrapperCodegen) -> None:
  5395. if config.size_asserts and not V.graph.cpp_wrapper:
  5396. # comparing strides for 0 size tensor is tricky. Ignore them for now.
  5397. if sympy_product(self.get_size()) == 0:
  5398. return
  5399. size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
  5400. stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
  5401. op_name = self.get_op_name()
  5402. wrapper.writeline(
  5403. f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
  5404. )
  5405. def codegen_alignment_asserts(self, wrapper: PythonWrapperCodegen) -> None:
  5406. if config.alignment_asserts and not V.graph.cpp_wrapper:
  5407. name = self.get_name()
  5408. aligned = name not in V.graph.unaligned_buffers
  5409. op_name = self.get_op_name()
  5410. if aligned:
  5411. wrapper.writeline(
  5412. f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
  5413. )
  5414. else:
  5415. wrapper.writeline(
  5416. f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
  5417. )
  5418. def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None:
  5419. """
  5420. Track outputs of fallback operators if config.test_configs.track_memory_lifecycle
  5421. """
  5422. if not config.test_configs.track_memory_lifecycle or V.graph.cpp_wrapper:
  5423. return
  5424. wrapper.write_memory_track_allocation_once()
  5425. name = self.get_name()
  5426. wrapper.writeline(f"track_tensor({name}, '{name}')")
  5427. def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]:
  5428. """
  5429. get output sizes and strides, for template_codegen
  5430. """
  5431. _size = self.get_size()
  5432. _stride = self.get_stride()
  5433. # iter_ranges = _size of output tensor, reduce_range = [] because no reduction
  5434. return [_size, []], _stride
  5435. def canonicalize(self) -> tuple[Expr, Sequence[Expr]]:
  5436. """
  5437. Manually get canonicalization of the output index
  5438. """
  5439. # manually generate index formula for conv
  5440. sizevars = V.graph.sizevars
  5441. sizes = self.get_size()
  5442. strides = self.get_stride()
  5443. strides = [sizevars.size_hint(x) for x in strides]
  5444. # TODO: I can't tell if the symbols here are temporary
  5445. index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))]
  5446. # reorder index vars according to stride
  5447. index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
  5448. lookup = {pos: idx for idx, pos in enumerate(index_order)}
  5449. order = [lookup[i] for i in range(len(lookup))]
  5450. index_vars = [index_vars[i] for i in order]
  5451. indexer = self.make_indexer()
  5452. index = indexer(index_vars)
  5453. new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
  5454. index_vars, sizes, [index]
  5455. )
  5456. # assign new variables each dimension to deal with numbering mismatches
  5457. # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
  5458. _, add_var = var_builder("c")
  5459. replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
  5460. index = sympy_subs(sympy.expand(index), replacement)
  5461. return index, tuple(new_sizes)
  5462. @cache_on_self_and_args("ExternKernel")
  5463. def get_free_symbol_uses(
  5464. self, unbacked_only: bool = False
  5465. ) -> OrderedSet[sympy.Symbol]:
  5466. # NB: It's not necessary to check regular inputs as we automatically
  5467. # have dependencies on them
  5468. maybe_get_symbols = (
  5469. maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols
  5470. )
  5471. r = InputsKernel.get_free_symbol_uses(self, unbacked_only)
  5472. for arg in self.constant_args:
  5473. r |= maybe_get_symbols(arg)
  5474. for arg in self.kwargs.values():
  5475. r |= maybe_get_symbols(arg)
  5476. return r
  5477. def __str__(self) -> str:
  5478. kernel_name = getattr(self, "python_kernel_name", None)
  5479. lines = [
  5480. f"python_kernel_name={kernel_name!r}",
  5481. ]
  5482. lines += [
  5483. f"{field.name}={getattr(self, field.name)}"
  5484. for field in dataclasses.fields(self)
  5485. ]
  5486. lines.append(f"origin_node={self.origin_node!r}")
  5487. return self.str_helper(lines)
  5488. __repr__ = __str__
  @ir_dataclass(frozen=False)
  class ExternKernelOut(ExternKernel):
      """Extern kernel emitted through ``wrapper.generate_extern_kernel_out``.

      ``should_allocate`` returns ``True`` for this kind of kernel.
      """

      def __init__(
          self,
          layout: Layout,
          inputs: Sequence[IRNode],
          constant_args: Sequence[Any] = (),
          kwargs: Optional[dict[str, Any]] = None,
          output_view: Optional[ReinterpretView] = None,
          python_kernel_name: Optional[str] = None,
          cpp_kernel_name: Optional[str] = None,
          ordered_kwargs_for_cpp_kernel: Sequence[Any] = (),
          op_overload: Optional[_OpOverloads] = None,
      ) -> None:
          """Unwrap the inputs, initialise the parent and register with the graph.

          NOTE(review): ``output_view`` is accepted but never read here; the
          sixth positional argument handed to the parent is a literal ``None``.
          Confirm this is intentional.
          """
          storage_inputs = self.unwrap_storage(inputs)
          assert isinstance(storage_inputs, Sequence), type(storage_inputs)
          super().__init__(
              None,
              layout,
              storage_inputs,
              constant_args,
              kwargs if kwargs else {},
              None,
              python_kernel_name,
              cpp_kernel_name,
              ordered_kwargs_for_cpp_kernel,
              op_overload,
          )
          # The buffer name is assigned by the graph at registration time.
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Delegate code generation to the wrapper."""
          wrapper.generate_extern_kernel_out(self)

      def should_allocate(self) -> bool:
          """Always ``True`` for this class."""
          return True
  class RandomSeeds(ExternKernelOut):
      """``ExternKernelOut`` configured to call ``aten.randint.low_out``.

      The output layout is a 1-D int64 ``FixedLayout`` of ``count`` elements
      on ``device``; the constant arguments are the int64 minimum, the int64
      maximum and the shape ``[count]``.
      """

      def __init__(self, count: int, device: torch.device) -> None:
          int64_info = torch.iinfo(torch.int64)
          seeds_layout = FixedLayout(
              device=device,
              dtype=torch.int64,
              size=[count],
          )
          super().__init__(
              layout=seeds_layout,
              inputs=[],
              constant_args=[int64_info.min, int64_info.max, [count]],
              python_kernel_name="aten.randint.low_out",
              # FIXME: Ideally we should only use at::_ops::randint_low_out::call here,
              # but its signature differs from at::randint_out. Again,
              # we can simplify the code when only keeping an ABI-compatible version.
              cpp_kernel_name="at::_ops::randint_low_out::call",
              op_overload=aten.randint.low_out,
          )
  class ExternKernelAlloc(ExternKernel):
      """Extern kernel emitted through ``wrapper.generate_extern_kernel_alloc``.

      ``should_allocate`` returns ``False`` for this kind of kernel.
      """

      def __init__(
          self,
          layout: OutputSpec,
          inputs: Sequence[IRNode],
          constant_args: Sequence[Any] = (),
          kwargs: Optional[dict[str, Any]] = None,
          python_kernel_name: Optional[str] = None,
          cpp_kernel_name: Optional[str] = None,
          ordered_kwargs_for_cpp_kernel: Sequence[Any] = (),
          op_overload: Optional[_OpOverloads] = None,
      ) -> None:
          """Unwrap the inputs, initialise the parent and register with the graph."""
          storage_inputs = self.unwrap_storage(inputs)
          assert all(isinstance(node, IRNode) for node in storage_inputs)
          super().__init__(
              None,
              layout,
              cast(Sequence[IRNode], storage_inputs),
              constant_args,
              kwargs if kwargs else {},
              None,
              python_kernel_name,
              cpp_kernel_name,
              ordered_kwargs_for_cpp_kernel,
              op_overload,
          )
          # We need output buffers for generating kernel arguments in the
          # abi-compatible mode, where we retrieve outputs by passing each
          # individual output through the abi-compatible interface.
          self.outputs: Sequence[Any] = []
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Delegate code generation to the wrapper."""
          wrapper.generate_extern_kernel_alloc(self)

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def apply_constraint(self) -> None:
          """Not implemented for this class; always raises."""
          raise NotImplementedError
  class MutationOutput(Buffer):
      """
      An output buffer that represents the mutation of a pre-existing buffer
      """

      def __init__(
          self, layout: OutputSpec, mutated_node: IRNode, mutating_node: Operation
      ) -> None:
          """Record the mutation of ``mutated_node`` performed by ``mutating_node``.

          The mutated buffer is flagged on the graph via
          ``V.graph.mark_buffer_mutated`` and this object is registered as a
          buffer, which also yields its name.
          """
          super().__init__(name=None, layout=layout)
          target_name = mutated_node.get_name()
          V.graph.mark_buffer_mutated(target_name)
          self.mutation_names = [target_name]
          self.mutating_node: Operation = mutating_node
          self.name = V.graph.register_buffer(self)

      def get_defining_op(self) -> Operation:
          """Return the operation that performs the mutation."""
          return self.mutating_node

      def get_mutation_names(self) -> Sequence[str]:
          """Return the names of the buffers this output mutates."""
          return self.mutation_names

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def get_mutation_buffers(self) -> Sequence[IRNode]:
          """Return the mutated buffers that ``V.graph.try_get_buffer`` can resolve.

          Names for which the lookup yields ``None`` are left out.
          """
          found = []
          for buffer_name in self.get_mutation_names():
              candidate = V.graph.try_get_buffer(buffer_name)
              if candidate is not None:
                  found.append(candidate)
          return found
  class TMADescriptor(ExternKernel):
      """
      An IR node representing a generic host-side TMA descriptor in the Triton API
      Mostly useful for user-defined Triton kernels relying on host-side TMA;
      but can, in principle, be used for Inductor's Triton templates, too.

      See TMADescriptorExperimental and TMADescriptorStable for the two implementations
      (the old API and the new API)
      """

      # as TMA descriptors are immutable,
      # we can dedup them by the input args
      _CACHE: dict[Any, TMADescriptor] = {}

      @classmethod
      def _create_impl(
          cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]]
      ) -> TMADescriptor:
          """Build the descriptor subclass selected by ``tma_meta[0]``.

          ``"experimental"`` selects ``TMADescriptorExperimental``; anything
          else is asserted to be ``"stable"`` and selects
          ``TMADescriptorStable``. ``tma_meta[1]`` is splatted as the
          remaining constructor arguments.
          """
          assert len(tma_meta) == 2
          api_kind = tma_meta[0]
          if api_kind != "experimental":
              assert api_kind == "stable"
              return TMADescriptorStable(tensor, *tma_meta[1])
          return TMADescriptorExperimental(tensor, *tma_meta[1])

      @classmethod
      def create(
          cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]]
      ) -> TMADescriptor:
          """Return the cached descriptor for ``(id(tensor), tma_meta)``.

          A new descriptor is built and stored in ``_CACHE`` on a miss.
          """
          cache_key = (id(tensor), tma_meta)
          descriptor = cls._CACHE.get(cache_key)
          if descriptor is None:
              descriptor = cls._create_impl(tensor, tma_meta)
              cls._CACHE[cache_key] = descriptor
          return descriptor

      def __init__(
          self, tensor: IRNode, inputs: Sequence[Any], constant_args: Sequence[Any]
      ) -> None:
          """Initialise the node, keep a reference to ``tensor`` and register it."""
          # link back to the underlying tensor in terms of ownership
          # to avoid getting the underlying tensor deleted *before*
          # the TMADescriptor node can be deleted.
          owner_layout = NonOwningLayout(
              ReinterpretView(
                  data=tensor,
                  layout=tensor.get_layout(),
              )
          )
          super().__init__(
              None,
              owner_layout,
              cast(Sequence[Buffer], inputs),
              tuple(constant_args),
              None,
          )

          self.tensor = tensor
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Delegate code generation to the wrapper."""
          wrapper.generate_tma_descriptor(self)

      def get_tensor(self) -> IRNode:
          """Return the tensor this descriptor was created for."""
          return self.tensor
  class TMADescriptorExperimental(TMADescriptor):
      """
      the old host-side TMA Descriptor API:
      (the ones obtained via create_{1d,2d}_tma_descriptor calls).

      See also TMADescriptorStable for the new API.
      """

      def __init__(
          self,
          tensor: IRNode,
          dims: list[Union[int, torch.SymInt]],
          block_dims: list[Union[int, torch.SymInt]],
          element_size: Optional[int] = None,
      ) -> None:
          """Create a 1-D or 2-D descriptor for ``tensor``.

          Args:
              tensor: the tensor node the descriptor refers to; it is also the
                  only input of the node.
              dims: one or two dimension sizes.
              block_dims: block sizes, one per entry of ``dims``.
              element_size: element size; defaults to the ``itemsize`` of the
                  tensor's dtype when omitted.
          """
          assert len(dims) in (1, 2)
          assert len(dims) == len(block_dims)

          if element_size is None:
              element_size = tensor.get_dtype().itemsize

          self.dims = dims
          self.block_dims = block_dims
          self.element_size = element_size
          self.rank = len(self.dims)

          inputs = [tensor]
          # Constant arguments are laid out as: dims..., block_dims..., element_size.
          constant_args = [
              *self.dims,
              *self.block_dims,
              self.element_size,
          ]

          super().__init__(
              tensor=tensor,
              inputs=inputs,
              constant_args=constant_args,
          )
  class TMADescriptorStable(TMADescriptor):
      """
      the new host-side TMA descriptor API
      (the ones obtained via TensorDescriptor.from_tensor).

      See also TMADescriptorExperimental for the old API.
      """

      def __init__(
          self, tensor: IRNode, block_shape: list[Union[int, torch.SymInt]]
      ) -> None:
          """Create a descriptor for ``tensor``.

          ``tensor`` is the node's only input and ``block_shape`` is passed on
          as its constant arguments.
          """
          self.block_shape = block_shape
          descriptor_inputs = [tensor]
          super().__init__(
              tensor=tensor,
              inputs=descriptor_inputs,
              constant_args=block_shape,
          )
  class SubgraphBuffer(ExternKernel):
      """Extern kernel backed by a subgraph built from an FX ``GraphModule``.

      The subgraph is created with ``V.graph.make_subgraph`` and run on the
      example inputs during construction.
      """

      def __init__(
          self,
          layout: Layout,
          input_nodes: list[Buffer],
          gm: torch.fx.GraphModule,
          example_inputs: list[Any],
          subgraph_name: str,
      ):
          """Register the node, then build and run the subgraph for ``gm``."""
          super().__init__(None, layout, input_nodes)
          self.gm = gm
          self.example_inputs = example_inputs
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)

          self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name)

          assert is_node_sequence(self.inputs)
          symbolic_inputs = get_symbolic_inputs(self.inputs)

          # Expose every symbolic input as a named input of the subgraph.
          subgraph = self.subgraph
          for symbol in symbolic_inputs:
              subgraph.graph_inputs[symbol.name] = symbol
              subgraph.graph_input_names.append(symbol.name)

          self.sym_inputs = [symbol.name for symbol in symbolic_inputs]

          import torch._inductor.config as inductor_config

          # Don't bother autotuning on Triton here
          no_autotune = inductor_config.patch(
              max_autotune=False,
              max_autotune_gemm=False,
              max_autotune_gemm_backends="ATEN",
          )
          with V.set_graph_handler(self.subgraph), no_autotune:
              self.subgraph.run(*self.example_inputs)

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit the subgraph call through the wrapper.

          The call receives the symbolic input names followed by the codegen
          references of the inputs, and ``[self.name]`` as the output list.
          """

          class CodegenGraph:
              # Small holder exposing ``graph`` and ``name`` attributes.
              def __init__(self, graph: GraphLowering):
                  self.graph = graph
                  self.name = graph.name

          assert is_node_sequence(self.inputs)
          input_refs = [node.codegen_reference() for node in self.inputs]
          call_inputs = [*self.sym_inputs, *input_refs]
          wrapper.codegen_subgraph_with_flattened_outputs(
              CodegenGraph(self.subgraph),
              call_inputs,
              [self.name],
          )
  class UserDefinedTritonKernel(ExternKernel):
      # IR node for a call to a user-defined Triton kernel. The kernel itself is
      # looked up by ``self.kernel_idx`` in ``kernel_side_table``.

      def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]:
          # Returns ``(kernel, configs, restore_value_args, reset_to_zero_args)``.
          # When the looked-up kernel is an ``Autotuner``, the returned kernel is
          # its ``fn`` and the other three entries are read from the autotuner;
          # otherwise ``configs`` and both name lists are empty.
          from triton.runtime.autotuner import Autotuner

          from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table

          kernel = kernel_side_table.get_kernel(self.kernel_idx)
          configs = []
          restore_value_args: list[str] = []
          reset_to_zero_args: list[str] = []
          if isinstance(kernel, Autotuner):
              # https://github.com/triton-lang/triton/pull/5083
              # changes kernel.restore_idx to kernel.restore_value
              if hasattr(kernel, "restore_idx"):
                  # Index-based attribute: map indices to argument names.
                  restore_value_args.extend(
                      kernel.fn.arg_names[i] for i in kernel.restore_idx
                  )
              else:
                  assert hasattr(kernel, "restore_value")
                  restore_value_args.extend(kernel.restore_value)

              # Same two attribute spellings for the reset-to-zero arguments.
              if hasattr(kernel, "reset_idx"):
                  for i in kernel.reset_idx:
                      reset_to_zero_args.append(kernel.fn.arg_names[i])
              else:
                  assert hasattr(kernel, "reset_to_zero")
                  reset_to_zero_args.extend(kernel.reset_to_zero)

              configs = kernel.configs
              kernel = kernel.fn
          return kernel, configs, restore_value_args, reset_to_zero_args

      @override
      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Overrides the parent member.
          See https://github.com/pytorch/pytorch/issues/151692"""
          from torch._inductor.utils import triton_version_uses_attrs_dict

          (
              kernel,
              configs,
              restore_value_args,
              reset_to_zero_args,
          ) = self.get_kernel_and_metadata()

          # Definition of kernel
          (
              new_name,
              triton_meta,
              extra_launch_args,
          ) = wrapper.define_user_defined_triton_kernel(
              kernel,
              configs,
              self.kwargs,
              restore_value_args,
              reset_to_zero_args,
              self.grid,
          )
          # Keyword arguments in the order recorded in ordered_kwargs_for_cpp_kernel.
          named_args = {
              k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
          }
          assert hasattr(kernel, "arg_names") and hasattr(kernel, "constexprs"), type(
              kernel
          )
          constexpr_names = OrderedSet(kernel.arg_names[i] for i in kernel.constexprs)

          # args/arg_types are the values and types passed to the kernel call;
          # raw_keys_filtered/raw_args_filtered track the names and raw values
          # that were kept.
          args: list[Any] = []
          arg_types: list[Any] = []
          raw_keys_filtered: list[Any] = []
          raw_args_filtered: list[Any] = []
          # Extra launch args carry no name, so they are paired with "".
          for name, arg in itertools.chain(
              named_args.items(), zip(itertools.repeat(""), extra_launch_args)
          ):
              if name in constexpr_names and triton_version_uses_attrs_dict():
                  # see #160000 - we don't pass in constexpr args to speed up runtime.
                  continue
              # Appended optimistically; the None branch below may pop them again.
              raw_keys_filtered.append(name)
              raw_args_filtered.append(arg)
              if isinstance(arg, IRNode):
                  args.append(arg.codegen_reference())
                  arg_types.append(arg.get_dtype())
              elif isinstance(arg, (int, float, bool, sympy.Expr)):
                  args.append(arg)
                  arg_types.append(type(arg))
              elif name in constexpr_names:
                  # insert a dummy value for constexpr args of unsupported type
                  # constexprs will end up getting baked into the kernel at compile time
                  args.append(-1)
                  arg_types.append(int)
              elif arg is None:
                  """
                  Filter out None args.
                  see https://github.com/pytorch/pytorch/issues/115344
                  Two cases for a None arg:
                  1. The arg is already tl.constexpr, so leave it in
                  2. The arg is not tl.constexpr so we have to remove it
                  """
                  if triton_version_uses_attrs_dict():
                      args.append(-1)
                      arg_types.append(int)
                  else:
                      # Undo the optimistic append above for this argument.
                      raw_keys_filtered.pop()
                      raw_args_filtered.pop()
              else:
                  raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}")

          self.codegen_comment(wrapper)
          wrapper.generate_kernel_call(
              new_name,
              args,
              arg_types=arg_types,
              raw_args=raw_args_filtered,
              raw_keys=raw_keys_filtered,
              triton_meta=triton_meta,
              triton=True,
              device=self.get_device(),
              original_fxnode_name=self.fx_node.name,
          )

      @cache_on_self_and_args("UserDefinedTritonKernel")
      def get_free_symbol_uses(
          self, unbacked_only: bool = False
      ) -> OrderedSet[sympy.Symbol]:
          # add unbacked symbols used in the grid to the ones used
          # in the kwargs (the latter is generated by ExternKernel)
          return super().get_free_symbol_uses(unbacked_only) | get_free_symbols(
              self.grid, unbacked_only
          )

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          # This node reports no unbacked symbol definitions.
          return OrderedSet()

      def __init__(
          self,
          *,
          kernel_idx: int,
          grid: Any,
          tma_descriptor_metadata: dict[str, Any],
          kernel_args: dict[str, Any],
      ) -> None:
          # Splits ``kernel_args`` into tensor inputs and constant arguments,
          # initialises the parent, then records argument order and the
          # mutation outputs for the tensors the kernel mutates.
          inputs: list[IRNode] = []
          kwargs: dict[str, IRNode] = {}
          constant_args: list[IRNode] = []

          for k, v in kernel_args.items():
              if isinstance(v, TensorBox):
                  t = InputsKernel.unwrap_storage_for_input(self.realize_input(v))
                  if k in tma_descriptor_metadata:
                      # Arguments with TMA metadata are passed as descriptors.
                      t = TMADescriptor.create(t, tma_descriptor_metadata[k])
                  inputs.append(t)
                  kwargs[k] = t
              else:
                  # Everything that is not a TensorBox is treated as a constant.
                  constant_args.append(v)
                  kwargs[k] = v

          assert len(inputs) != 0
          # The device is taken from the first tensor input.
          self.device = inputs[0].get_device()

          assert isinstance(inputs, Sequence), type(inputs)
          super().__init__(
              None,
              NoneLayout(device=self.device),
              inputs,
              tuple(constant_args),
              kwargs,
          )
          self.kernel_idx = kernel_idx
          self.grid = grid

          kernel, configs, _, _ = self.get_kernel_and_metadata()

          # If we are autotuning, not all arguments will be passed
          assert hasattr(kernel, "arg_names")
          self.ordered_kwargs_for_cpp_kernel = [
              arg for arg in kernel.arg_names if arg in kernel_args
          ]

          from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors

          # Only the first autotuning config's kwargs are merged in here.
          autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {}
          self.mutable_args = [
              kernel_args[key]
              for key in identify_mutated_tensors(
                  kernel, {**kernel_args, **autotuned_kwargs}, tma_descriptor_metadata
              )
          ]

          # One MutationOutput per mutated argument.
          self.mutation_outputs = [
              MutationOutput(NoneLayout(device=self.device), buf, self)
              for buf in self.mutable_args
          ]
          V.graph.register_operation(self)

      def get_outputs(self) -> list[Buffer]:
          # Returns a fresh list of the mutation outputs.
          return list(self.mutation_outputs)

      def get_device(self) -> Optional[torch.device]:
          # Device recorded from the first tensor input in __init__.
          return self.device
  class InplaceBernoulliFallback(ExternKernel):
      """
      This needs to be a custom class to handle mutation properly
      """

      def __init__(
          self, op_overload: _OpOverloads, x: IRNode, *constant_args: Any
      ) -> None:
          """Initialise with ``x`` as the only input and mark it as mutated."""
          output_layout = NoneLayout(device=x.get_device())
          super().__init__(
              None,
              output_layout,
              self.unwrap_storage([x]),
              constant_args,
              op_overload=op_overload,
          )
          V.graph.mark_buffer_mutated(x.get_name())
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Write the kernel call as a single line of wrapper code."""
          assert all(isinstance(node, IRNode) for node in self.inputs)
          (target,) = (cast(IRNode, node).codegen_reference() for node in self.inputs)

          # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here,
          # which needs to be explicitly generated for cpp wrapper
          generator_arg = ", NULL" if V.graph.cpp_wrapper else ""
          kernel_name = self.get_kernel_name()
          constant_refs = ", ".join(map(repr, self.constant_args))
          wrapper.writeline(
              f"{kernel_name}({target}, {constant_refs}{generator_arg}){wrapper.ending}"
          )

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def get_mutation_names(self) -> Sequence[str]:
          """The first input is the buffer being mutated."""
          return [self.input_name(0)]

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          """This node reports no unbacked symbol definitions."""
          return OrderedSet()
  5958. # Used to deal with torch.complex types
  class InplaceCopyFallback(ExternKernel):
      """
      This needs to be a custom class to handle mutation properly
      """

      def __init__(
          self,
          layout: OutputSpec,
          inputs: Sequence[IRNode],
          constant_args: Sequence[Any],
      ) -> None:
          """Initialise the copy node and mark ``inputs[0]`` as mutated."""
          super().__init__(
              None,
              layout,
              inputs,
              constant_args,
              python_kernel_name="aten.copy_",
              cpp_kernel_name="aoti_torch_copy_",
          )
          destination = inputs[0]
          V.graph.mark_buffer_mutated(destination.get_name())
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)

      @classmethod
      def create(
          cls, dst: IRNode, src: IRNode, non_blocking: bool = False
      ) -> InplaceCopyFallback:
          """Realize ``dst`` and ``src`` and build a copy node on ``dst``'s device."""
          realized_dst = cls.realize_input(dst)
          realized_src = cls.realize_input(src)
          return InplaceCopyFallback(
              NoneLayout(device=dst.get_device()),
              [realized_dst, realized_src],
              (non_blocking,),
          )

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit a device copy from the source to the destination reference."""
          (dst_ref, src_ref, non_blocking_ref) = self.codegen_args()
          wrapper.codegen_device_copy(src_ref, dst_ref, non_blocking_ref)

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def get_mutation_names(self) -> Sequence[str]:
          """The first input is the buffer being mutated."""
          return [self.input_name(0)]

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          """This node reports no unbacked symbol definitions."""
          return OrderedSet()
  class MutatingFirstArgExternKernel(ExternKernel):
      """
      This needs to be a custom class to handle mutation properly
      """

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Write ``kernel(inputs..., constants...)`` as one line of wrapper code.

          Inputs contribute their codegen references; constant arguments
          contribute their ``repr``.
          """
          assert is_node_sequence(self.inputs)
          input_refs = [node.codegen_reference() for node in self.inputs]
          constant_refs = [repr(value) for value in self.constant_args]
          call_args = input_refs + constant_refs
          wrapper.writeline(
              f"{self.get_kernel_name()}({', '.join(call_args)}){wrapper.ending}"
          )

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def get_mutation_names(self) -> Sequence[str]:
          """The first input is the buffer being mutated."""
          return [self.input_name(0)]

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          """This node reports no unbacked symbol definitions."""
          return OrderedSet()

      def has_side_effects(self) -> bool:
          """Always ``True`` for this class."""
          return True
  class ResizeStorageBytes(MutatingFirstArgExternKernel):
      """Node calling ``resize_storage_bytes_`` on ``variable`` with a fixed size."""

      def __init__(self, variable: IRNode, new_size: int) -> None:
          """Build the node, mark ``variable`` mutated and exclude it from reuse.

          Only plain ``int`` sizes are accepted.
          """
          assert isinstance(new_size, int), "TODO: dynamic shapes"
          output_layout = NoneLayout(device=variable.get_device())
          storage_inputs = self.unwrap_storage([variable])
          super().__init__(
              None,
              output_layout,
              storage_inputs,
              constant_args=(new_size,),
          )
          V.graph.mark_buffer_mutated(variable.get_name())
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)
          self.python_kernel_name = "inductor_ops.resize_storage_bytes_"
          self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_"
          assert isinstance(variable, (BaseView, StorageBox, TensorBox)), type(variable)
          # The buffer wrapped by ``variable`` is added to the graph's never-reuse set.
          wrapped_name = variable.data.get_name()
          V.graph.never_reuse_buffers.add(wrapped_name)
  class SetSourceTensorKernel(ExternKernelAlloc):
      """Node calling ``torch.ops.aten.set_.source_Tensor``.

      Both tensors are recorded as mutated and as aliasing the output.
      """

      def __init__(self, self_tensor: IRNode, storage_tensor: IRNode) -> None:
          """Freeze ``storage_tensor``'s layout and build the node on top of it."""
          storage_tensor.freeze_layout()
          super().__init__(
              storage_tensor.get_layout(),
              [self_tensor, storage_tensor],
              python_kernel_name="torch.ops.aten.set_.source_Tensor",
              op_overload=torch.ops.aten.set_.source_Tensor,
          )
          assert isinstance(self_tensor, (BaseView, StorageBox, TensorBox)), type(
              self_tensor
          )
          # Neither tensor nor this node's own buffer may be reused.
          V.graph.never_reuse_buffers.add(self_tensor.data.get_name())
          V.graph.never_reuse_buffers.add(storage_tensor.get_name())
          V.graph.never_reuse_buffers.add(self.get_name())

          device = storage_tensor.get_device()
          self.mutation_outputs = [
              MutationOutput(NoneLayout(device=device), mutated, self)
              for mutated in (self_tensor, storage_tensor)
          ]

      def get_inputs_that_alias_output(self) -> Sequence[str]:
          """Return the names of the first two inputs."""
          return [self.input_name(position) for position in (0, 1)]
  class ScatterFallback(ExternKernel):
      """
      This needs to be a custom class to handle mutation properly.
      This class handles both aten.scatter_ and aten.scatter_reduce_.
      It also handles the case of `src` being a scalar properly.
      """

      def __init__(
          self,
          op_overload: _OpOverloads,
          x: IRNode,
          dim: int,
          index: IRNode,
          src: IRNode,
          *,
          reduce: Optional[str] = None,
          include_self: bool = True,
      ) -> None:
          """Build the node and mark ``x`` as mutated.

          When ``src`` is a ``TensorBox`` it becomes the third input;
          otherwise it is stored as the second constant argument after
          ``dim``.
          """
          self.src_is_tensor = isinstance(src, TensorBox)

          tensor_operands = [x, index]
          constant_args: tuple[Any, ...] = (dim,)
          if self.src_is_tensor:
              tensor_operands.append(src)
          else:
              constant_args = (dim, src)
          tensors = [self.realize_input(operand) for operand in tensor_operands]

          super().__init__(
              None,
              NoneLayout(device=x.get_device()),
              self.unwrap_storage(tensors),
              constant_args,
              {"reduce": reduce, "include_self": include_self},
              python_kernel_name=str(op_overload),
              ordered_kwargs_for_cpp_kernel=["reduce", "include_self"],
              op_overload=op_overload,
          )
          V.graph.mark_buffer_mutated(x.get_name())
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit the scatter call through ``wrapper.generate_scatter_fallback``."""
          reduce = self.kwargs["reduce"]
          if V.graph.cpp_wrapper:
              # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
              operator_enum_names = {"add": "sum", "multiply": "prod"}
              reduce = operator_enum_names.get(reduce, reduce)

          assert is_node_sequence(self.inputs)
          input_refs = (node.codegen_reference() for node in self.inputs)
          if self.src_is_tensor:
              (x, index, src) = input_refs
          else:
              (x, index) = input_refs
              # A scalar source lives in the constant arguments, after ``dim``.
              src = self.constant_args[1]

          wrapper.generate_scatter_fallback(
              x,
              [x, self.constant_args[0], index, src],
              self.cpp_kernel_name,
              self.python_kernel_name,
              self.src_is_tensor,
              reduce,
              self.codegen_kwargs(),
          )

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def get_mutation_names(self) -> list[str]:
          """The first input is the buffer being mutated."""
          mutated = self.inputs[0]
          assert isinstance(mutated, IRNode)
          return [mutated.get_name()]

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          """This node reports no unbacked symbol definitions."""
          return OrderedSet()
  class IndexPutFallback(ExternKernel):
      """
      This needs to be a custom class to handle mutation and indices properly
      """

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit the index_put call through ``wrapper.generate_index_put_fallback``.

          The inputs are ordered as ``x``, ``values`` and then one entry per
          non-``None`` index. The index list passed to the wrapper has one
          entry per original index, with ``None`` positions filled by the
          wrapper's ``none_str``.
          """
          assert is_node_sequence(self.inputs)
          (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs)
          # Consume the non-None index references in order while walking the
          # original index list.
          remaining_refs = iter(valid_indices)
          indices = [
              next(remaining_refs)
              if index is not None
              else V.graph.wrapper_code.none_str
              for index in self.indices
          ]

          wrapper.generate_index_put_fallback(
              self.get_kernel_name(), x, indices, values, *self.codegen_const_args()
          )

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def get_mutation_names(self) -> Sequence[str]:
          """The first input is the buffer being mutated."""
          return [self.input_name(0)]

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          """This node reports no unbacked symbol definitions."""
          return OrderedSet()

      def __init__(
          self,
          op_overload: torch._ops.OpOverload,
          x: IRNode,
          indices: list[Any],
          values: Sequence[Any],
          accumulate: Any,
      ) -> None:
          """Build the node and mark its first input as mutated.

          Args:
              op_overload: the operator overload forwarded to the parent.
              x: the tensor being written to.
              indices: index entries; ``None`` entries are kept in
                  ``self.indices`` but are not turned into inputs.
              values: the values to write.
              accumulate: stored as the only constant argument.
          """
          self.indices = indices
          valid_indices = [i for i in indices if i is not None]
          # Use a distinct loop variable so the parameter ``x`` is not shadowed.
          tensors = [self.realize_input(t) for t in [x, values, *valid_indices]]
          cpp_kernel_name = "aoti_torch_index_put_out"
          super().__init__(
              None,
              NoneLayout(device=x.get_device()),
              self.unwrap_storage(tensors),
              (accumulate,),
              python_kernel_name="aten.index_put_",
              cpp_kernel_name=cpp_kernel_name,
              op_overload=op_overload,
          )
          V.graph.mark_buffer_mutated(self.input_name(0))
          self.name = V.graph.register_buffer(self)
          V.graph.register_operation(self)
  class DeviceCopy(ExternKernelOut):
      """Node that copies a tensor to another device."""

      @classmethod
      def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode:
          """Return a node holding ``x`` on ``device``.

          If ``x`` is not extern, reads only graph constants and runtime
          constant folding is off, the result is ``x.constant_to_device(device)``
          and no copy node is created. Otherwise ``x`` is made contiguous and a
          ``DeviceCopy`` with a ``FixedLayout`` on ``device`` is returned.
          """
          reads_only_constants = (
              not x.is_extern()
              and all(r in V.graph.constants for r in x.get_read_names())
              and not config.aot_inductor.use_runtime_constant_folding
          )
          if reads_only_constants:
              return x.constant_to_device(device)

          V.graph.add_device_info(device)
          source_device = x.get_device()
          assert source_device is not None
          V.graph.add_device_info(source_device)

          developer_warning("DeviceCopy in input program")
          constant_args = (non_blocking,)
          # Device Copy should keep the same layout as input
          x = ExternKernel.require_contiguous(x)
          # x.get_stride() may be unimplemented if x's size is empty
          stride = x.get_stride() if x.get_size() else None

          # A non-blocking GPU -> CPU copy marks the destination as pinned ...
          is_destination_pinned = (
              is_gpu(source_device.type) and device.type == "cpu" and non_blocking
          )
          # ... and a non-blocking CPU -> GPU copy marks the source as pinned.
          is_source_pinned = (
              source_device.type == "cpu" and is_gpu(device.type) and non_blocking
          )
          if is_source_pinned and is_storage_and_layout(x):
              x.get_layout().is_pinned = True

          destination_layout = FixedLayout(
              device,
              x.get_dtype(),
              x.get_size(),
              stride,
              is_pinned=is_destination_pinned,
          )
          return DeviceCopy(
              destination_layout,
              [cls.realize_input(x)],
              constant_args,
          )

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Emit the copy into ``output_view`` when it is set, else into this buffer."""
          args = self.codegen_args()
          assert len(args) == 2
          destination = self.output_view if self.output_view else self
          wrapper.codegen_device_copy(args[0], destination.codegen_reference(), args[1])
  class DynamicSelectStorageOffset(ExternKernel):
      """
      The result of computing a dynamic selection index is determined as follows: when the index in the
      select operation is unbacked, the actual index calculation is ambiguous for negative indices
      (index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked
      SymInt to represent the storage offset and decompose the select operation into a call to as_strided,
      computing the storage offset at runtime with this node.
      """

      def __init__(
          self,
          unbacked_offset_symbol: sympy.Symbol,
          index: sympy.Symbol,
          base_offset: Union[sympy.Symbol, int],
          base_dim_stride: Union[sympy.Symbol, int],
          size: Union[sympy.Symbol, int],
      ) -> None:
          """Create an input-less CPU node and store the operands of the offset formula."""
          cpu_layout = NoneLayout(device=torch.device("cpu"))
          super().__init__(None, cpu_layout, [])
          # This node codegen the following:
          # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size)
          self.unbacked_offset_symbol = unbacked_offset_symbol
          self.index = index
          self.base_offset = base_offset
          self.base_dim_stride = base_dim_stride
          self.size = size

      def get_reads(self) -> OrderedSet[Dep]:
          """This node reports no reads."""
          return OrderedSet()

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          """The offset symbol is the only symbol this node defines."""
          return OrderedSet([self.unbacked_offset_symbol])

      @cache_on_self_and_args("DynamicSelectStorageOffset")
      def get_free_symbol_uses(
          self, unbacked_only: bool = False
      ) -> OrderedSet[sympy.Symbol]:
          """Return the free symbols of ``self.index``."""
          return get_free_symbols(self.index, unbacked_only)

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Delegate code generation to the wrapper."""
          wrapper.codegen_dynamic_select_index(self)
  class DynamicScalar(ExternKernel):
      """
      The result of a call to aten._local_scalar_dense.
      """

      def __init__(
          self, sym: sympy.Symbol, keypath: pytree.KeyPath, data: IRNode
      ) -> None:
          """Realize ``data``, use it as the only input and store ``sym`` and ``keypath``."""
          data.realize()
          cpu_layout = NoneLayout(device=torch.device("cpu"))
          super().__init__(None, cpu_layout, self.unwrap_storage([data]))
          self.sym = sym
          self.keypath = keypath

      def get_reads(self) -> OrderedSet[Dep]:
          """This node reports no reads."""
          return OrderedSet()

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
          """``self.sym`` is the only symbol this node defines."""
          return OrderedSet([self.sym])

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Delegate code generation to the wrapper."""
          wrapper.codegen_dynamic_scalar(self)
  class AssertScalar(ExternKernel):
      """
      The result of a call to aten._assert_scalar
      """

      def __init__(self, scalar: SympyBoolean, msg: str) -> None:
          """Create an input-less CPU node holding the condition and its message."""
          super().__init__(
              # Buffer(name, layout)
              None,
              NoneLayout(device=torch.device("cpu")),
              # InputsKernel(inputs)
              [],
          )
          self.scalar = scalar
          self.msg = msg

      def get_reads(self) -> OrderedSet[Dep]:
          """This node reports no reads."""
          return OrderedSet()

      def should_allocate(self) -> bool:
          """Always ``False`` for this class."""
          return False

      def has_side_effects(self) -> bool:
          """Always ``True`` for this class."""
          return True

      @cache_on_self_and_args("AssertScalar")
      def get_free_symbol_uses(
          self, unbacked_only: bool = False
      ) -> OrderedSet[sympy.Symbol]:
          """Return the free symbols of the asserted condition."""
          return get_free_symbols(self.scalar, unbacked_only)

      def codegen(self, wrapper: PythonWrapperCodegen) -> None:
          """Write the runtime check into the wrapper code.

          Nothing is emitted when ``config.scalar_asserts`` is off or when the
          graph uses the fx wrapper.
          """
          if not config.scalar_asserts:
              return
          # NB: It is EXTREMELY important not to simplify the scalar under assertion here,
          # because simplify is done with respect to runtime asserts. So if you have
          # "u0 == 0" in the runtime asserts, if you subsequently try to
          # simplify(u0 == 0), you will get True (because we've already runtime assert'ed
          # that it's true). But we're code generating the actual runtime assert here!!
          # NOTE(review): ``symbol`` is only used on the cpp_wrapper path but is
          # computed unconditionally; ``next`` raises StopIteration if the scalar
          # has no free symbols -- confirm that cannot happen.
          symbol = next(iter(self.get_free_symbol_uses(unbacked_only=False)))

          if V.graph.fx_wrapper:
              # TODO fix
              return

          if V.graph.cpp_wrapper:
              symbol_str = f"std::to_string({symbol})"
              sizevar = V.graph.wrapper_code.codegen_cpp_sizevar(
                  self.scalar, simplify=False
              )
              # TODO: when we start compiling in C++20, annotate with [[unlikely]].
              wrapper.writeline(
                  f'if (!({sizevar})) {{ throw std::runtime_error("Expected {self.msg} but received " + {symbol_str}); }}'
              )
              return

          sizevar = V.graph.wrapper_code.codegen_python_sizevar(
              self.scalar, simplify=False
          )
          wrapper.writeline(f"if not ({sizevar}):")
          wrapper.writeline(f" raise RuntimeError({repr(self.msg)})")
          # No one should ever use this buffer, but for uniformity
          # define the variable and assign it None
          wrapper.writeline(f"{self.get_name()} = None")
# Pairs a kernel's name with its export-schema node. Instances are built in
# FallbackKernel.export_extern_kernel_node and appended to V.extern_kernel_nodes.
@ir_dataclass(frozen=False)
class ExternKernelNode:
    # Name of the IR node the entry was created for (taken from get_name()).
    name: str
    # Serialized form: target op name, inputs, outputs and metadata.
    node: export_schema.Node
class FallbackKernel(ExternKernelAlloc):
    """
    A class that represents a fallback kernel for handling operators that are not
    directly supported by inductor. It currently supports functional ops, view ops,
    inplace aten ops, and mutating ops that are auto-functionalizable.
    """

    def __init__(
        self,
        layout: OutputSpec,
        kernel: _OpOverloads,
        tensor_args: Sequence[IRNode],
        nontensor_args: Sequence[Any],
        unflatten_args: Callable[..., Any],
        kwargs: Optional[dict[str, Any]] = None,
        *,
        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
    ) -> None:
        """Build the fallback node and record which inputs it aliases or mutates.

        ``tensor_args`` and ``nontensor_args`` are forwarded to the parent as
        tuples. ``unflatten_args`` is a callable that rebuilds ``(args, kwargs)``
        from the flat tensor / non-tensor sequences. After the basic fields are
        stored, the op schema is inspected to fill ``alias_names``,
        ``mutation_names`` and ``mutation_outputs``; higher-order operators and
        ``_c10d_functional`` ops skip that inspection.

        Raises:
            NotImplementedError: the schema is mutable and the op is not
                auto-functionalizable.
        """
        super().__init__(
            layout,
            tuple(tensor_args),
            tuple(nontensor_args),
            op_overload=kernel,
        )

        self.use_runtime_dispatch = False
        self.unbacked_bindings = unbacked_bindings or {}

        assert isinstance(
            kernel, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
        ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
        self.op_overload = kernel
        self.unflatten_args = unflatten_args
        self.kwargs = {} if kwargs is None else kwargs
        assert self.python_kernel_name is not None
        V.graph.warn_fallback(self.python_kernel_name)

        # args that are aliased
        self.alias_names: list[str] = []
        # args that are mutated AND returned from the op
        self.mutation_names: list[str] = []

        if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
            # We assume here that HOPs with FallbackKernel are functional.
            # This may not always be true! HOPs must individually opt-in to
            # FallbackKernel, so please check this if you opt-in.
            return

        if "_c10d_functional" in self.op_overload.name():
            # _c10d_functional kernels are lowered into _CollectiveKernel which
            # derives from FallbackKernel for the cpp codegen. The kernels
            # don't pass the can_auto_functionalize check, but their mutation
            # is handled properly by _CollectiveKernel.
            return

        schema = self.op_overload._schema

        # NOTE: [FallbackKernel supported operators]
        # We only support four kinds of operators:
        # - functional ops
        # - view ops
        # - inplace aten ops
        # - mutating ops that are auto-functionalizable. That is,
        #   the operator may mutate any number of inputs, but its outputs
        #   may not alias any of the inputs.
        #
        # The unsupported cases usually do not show up here (because
        # AOTAutograd functionalized them away); the only way for an in-place
        # op to show up here is if a lowering or pass introduced it.
        if torch._library.utils.mutates_and_returns_first_arg(self.op_overload):
            self.mutation_names.append(tensor_args[0].get_name())
            return

        if schema.is_mutable and not can_auto_functionalize(kernel):
            raise NotImplementedError(
                f"NYI: Can't generate FallbackKernel for {kernel}"
            )

        args, kwargs = self.unflatten_args(self.inputs, self.constant_args)

        def handle_aliasing_and_mutation(info: torch._C.Argument, arg: Any) -> None:
            # Records alias / mutation bookkeeping for one (schema argument,
            # actual value) pair. Arguments that are None or that carry no
            # alias annotation are ignored.

            # Assertions to make sure we didn't mismatch args
            if isinstance(info.type, torch.ListType):
                assert isinstance(arg, (list, tuple)), type(arg)
            if library_utils.is_tensor_like_type(info.type):
                # PyTorch also accepts None and scalar types for args marked as "Tensor".
                # We're not going to check all of them here.
                assert not isinstance(arg, (tuple, list))

            if arg is None:
                return
            if info.alias_info is None:
                return

            def add_alias(t: IRNode) -> None:
                # Every annotated argument is recorded as an alias; one whose
                # annotation is a write additionally gets a MutationOutput.
                self.alias_names.append(t.get_name())
                assert info.alias_info is not None
                if info.alias_info.is_write:
                    self.mutation_outputs.append(
                        MutationOutput(NoneLayout(device=t.get_device()), t, self)
                    )

            if library_utils.is_tensorlist_like_type(info.type):
                if arg is not None:
                    for optional_tensor_arg in arg:
                        add_alias(optional_tensor_arg)
            else:
                assert library_utils.is_tensor_like_type(info.type)
                add_alias(arg)

        for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
            handle_aliasing_and_mutation(info, arg)

    def get_read_writes(self) -> dependencies.ReadWrites:
        """Return the parent's read/write set, extended for RNG-state ops.

        For ``graphsafe_run_with_rng_state`` every ``GeneratorState`` found in
        ``constant_args`` is added as a ``StarDep`` read.
        """
        read_writes = super().get_read_writes()

        if self.op_overload is torch._prims.rng_prims.graphsafe_run_with_rng_state:
            for arg in self.constant_args:
                if isinstance(arg, GeneratorState):
                    read_writes = read_writes.with_read(
                        dependencies.StarDep(arg.get_name())
                    )

        return read_writes

    def codegen_unbacked_symbol_defs(self, wrapper: PythonWrapperCodegen) -> None:
        """Delegate to ``wrapper.codegen_unbacked_symbol_defs_for_outputs``.

        ``unbacked_bindings`` is read with ``getattr`` so a missing attribute
        is passed along as ``None``.
        """
        return wrapper.codegen_unbacked_symbol_defs_for_outputs(
            self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None)
        )

    def get_unbacked_symbol_defs(self) -> Container[sympy.Symbol]:  # type: ignore[override]
        """Return the symbols of the resolved unbacked bindings.

        An empty ``OrderedSet`` is returned when the attribute is missing or
        falsy; otherwise the keys of the bindings resolved against the graph's
        shape environment are returned.
        """
        if unbacked_bindings := getattr(self, "unbacked_bindings", None):
            resolved = resolve_unbacked_bindings(
                V.graph.sizevars.shape_env, unbacked_bindings
            )
            assert resolved is not None
            return resolved.keys()
        else:
            return OrderedSet()

    def codegen_args(self) -> list[str]:
        """Return the positional arguments rendered as wrapper-code strings.

        Side effect: the keyword arguments recovered by ``unflatten_args`` are
        merged into ``self.kwargs`` so that ``codegen_kwargs`` can emit them.
        """

        # Wraps an already-generated reference string so that repr() of the
        # wrapper yields the string itself.
        @dataclasses.dataclass
        class Shim:
            ref: Any

            def __repr__(self) -> str:
                return self.ref

        assert is_node_sequence(self.inputs)
        tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
        args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
        if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
            # The cpp wrapper path fills in arguments the caller omitted and
            # renders each value together with its schema type.
            args = self.fill_non_provided_args(args, kwargs)
            args = [
                V.graph.wrapper_code.val_to_arg_str(x, param.real_type)
                for param, x in zip(self.op_overload._schema.arguments, args)
            ]
        else:
            args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]

        # let self.codegen_kwargs handle kwargs
        self.kwargs.update(kwargs)
        return args

    @staticmethod
    def find_device(
        tensor_args: Optional[Sequence[torch.Tensor]], example_output: Sequence[Any]
    ) -> Any:
        """Pick a device for the kernel from its inputs or its example output.

        Lookup order:
          1. if any input is not a ``TorchBindObject``, the first truthy
             ``get_device()`` among all of ``tensor_args``;
          2. the device of ``example_output`` when it is a tensor;
          3. for a list/tuple output, the unique device found recursively, else
             the first GPU device, else the first device found.
        Returns ``None`` when none of the above applies.

        NOTE(review): ``devices[0]`` in steps 1 and 3 looks like it raises
        ``IndexError`` when no device was collected (e.g. an empty list
        output) - confirm callers never hit that case.
        """
        non_torch_bind_tensor_args = (
            [t for t in tensor_args if not isinstance(t, TorchBindObject)]
            if tensor_args
            else None
        )
        if non_torch_bind_tensor_args:
            assert tensor_args
            devices = [arg.get_device() for arg in tensor_args if arg.get_device()]
            return devices[0]
        if isinstance(example_output, torch.Tensor):
            return example_output.device
        if isinstance(example_output, (list, tuple)):
            device_set = OrderedSet(
                FallbackKernel.find_device(None, x) for x in example_output
            )
            # Remove None
            devices = [device for device in device_set if device]
            if len(devices) == 1:
                return devices[0]
            for device in devices:
                assert isinstance(device, torch.device)
                if is_gpu(device.type):
                    return device
            return devices[0]
        return None

    def has_side_effects(self) -> bool:
        """Return ``False`` for higher-order operators, else the schema's mutability."""
        if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
            return False
        return get_schema_info(self.op_overload).is_mutable()

    def get_inputs_that_alias_output(self) -> Sequence[str]:
        """Return the names of inputs that the output may alias.

        The result is empty for a mutable, auto-functionalizable op that is
        neither a higher-order operator nor a ``_c10d_functional`` op;
        otherwise it is ``self.alias_names``.
        """
        assert isinstance(
            self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
        ), (
            f"Fails to create FallbackKernel for {self.op_overload}: "
            f"{type(self.op_overload)} not supported"
        )

        # See [Note: FallbackKernel supported operators]: for a mutating
        # op that is auto-functionalizable, its outputs does NOT
        # alias any of the inputs.
        if (
            not isinstance(self.op_overload, torch._ops.HigherOrderOperator)
            and "_c10d_functional" not in self.op_overload.name()
            and self.op_overload._schema.is_mutable
            and can_auto_functionalize(self.op_overload)
        ):
            return []
        else:
            return self.alias_names

    def get_mutation_names(self) -> Sequence[str]:
        """Return ``self.mutation_names``, asserting it has at most one entry."""
        assert len(self.mutation_names) <= 1
        return self.mutation_names

    def export_extern_kernel_node(self):  # type: ignore[no-untyped-def]
        """
        ProxyExecutor Design Note
        We export the ExternFallbackNodes (for custom ops) into a serialized file
        and run it with a host side proxy executor to address the ABI problem
        This is currently only implemented for fbcode. Eventually, we will also make this work for OSS.
        Detailed design doc can be found at
        https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing

        Returns the positional arguments followed by the ordered keyword
        values. In AOT mode an ``ExternKernelNode`` describing the call is
        additionally appended to ``V.extern_kernel_nodes``.
        """
        log.debug(
            "Extern kernel node added for node %s with target %s.",
            self.get_name(),
            self.op_overload,
        )

        assert isinstance(self, FallbackKernel), type(self)
        args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
        args = self.fill_non_provided_args(args, kwargs)
        ordered_kwargs = [
            self.get_kwargs_value(key, **kwargs)
            for key in self.ordered_kwargs_for_cpp_kernel
        ]
        target = self.op_overload

        if not V.graph.aot_mode:
            # No need to serialize in the cpp wrapper JIT mode
            return [*args, *ordered_kwargs]

        serializer = GraphModuleSerializer(None, [])  # type: ignore[arg-type]
        named_arguments = serializer.serialize_inputs(target, args, kwargs)

        # serialize_outputs
        def handle_single_output(
            return_type: Union[torch.TensorType, torch.ListType, torch.JitType],
            output: Union[IRNode, Sequence[IRNode]],
        ) -> export_schema.Argument:
            # Converts one schema return slot plus its IR output into an
            # export-schema Argument. Handled return types: Tensor, None,
            # Tensor[], Tensor? and int; anything else raises RuntimeError.
            if isinstance(return_type, (torch.TensorType, torch.NoneType)):
                # For single Tensor or None
                out = output
                if isinstance(output, (list, tuple)):
                    assert len(output) == 1
                    out = output[0]
                if isinstance(return_type, torch.TensorType):
                    assert isinstance(out, IRNode)
                    return export_schema.Argument.create(
                        as_tensor=export_schema.TensorArgument(name=out.get_name())
                    )
                else:  # NoneType
                    assert out is None
                    return export_schema.Argument.create(as_none=True)
            elif isinstance(return_type, torch.ListType) and isinstance(
                return_type.getElementType(), torch.TensorType
            ):
                assert isinstance(output, Sequence), type(output)
                # For single TensorList
                return export_schema.Argument.create(
                    as_tensors=[
                        export_schema.TensorArgument(name=out.get_name())
                        for out in output
                    ]
                )
            elif isinstance(return_type, torch.OptionalType) and isinstance(
                return_type.getElementType(), torch.TensorType
            ):
                # For OptionalTensor
                if output is None:
                    return export_schema.Argument.create(
                        as_optional_tensor=export_schema.OptionalTensorArgument.create(
                            as_none=True
                        )
                    )
                else:
                    assert isinstance(output, IRNode)
                    return export_schema.Argument.create(
                        as_optional_tensor=export_schema.OptionalTensorArgument.create(
                            as_tensor=export_schema.TensorArgument(
                                name=output.get_name()
                            )
                        )
                    )
            elif isinstance(return_type, torch.IntType):
                return export_schema.Argument.create(as_int=output)
            else:
                raise RuntimeError(f"Unsupported return type {type(return_type)}")

        # CallTorchBind derives its schema from the first two arguments;
        # every other target exposes it as `_schema`.
        if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind):
            returns = target.schema(args[0], args[1]).returns
        else:
            returns = target._schema.returns  # type: ignore[union-attr]
        if len(returns) == 1:
            # NOTE: [special handling of all_reduce_coalesced_'s return value]
            # all_reduce_coalesced_ return a list of tensors via self.mutation_outputs
            outputs = self.outputs if self.outputs else self.mutation_outputs
            return_type = returns[0].real_type
            output_arguments = [handle_single_output(return_type, outputs)]
        else:
            # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tensor, Tensor[])"
            # Not generating output args for self.mutation_outputs
            output_arguments = [
                handle_single_output(
                    return_schema.real_type,  # type: ignore[attr-defined]
                    output,
                )
                for return_schema, output in zip(returns, self.outputs)
            ]

        assert self.op_overload is not None
        node = ExternKernelNode(
            name=self.get_name(),
            node=export_schema.Node(
                target=self.op_overload.name(),
                inputs=named_arguments,
                outputs=output_arguments,
                metadata={},
            ),
        )

        V.extern_kernel_nodes.append(node)

        return [*args, *ordered_kwargs]

    @override
    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        """Overrides the parent member.
        See https://github.com/pytorch/pytorch/issues/151692

        First decides whether the call must go through runtime dispatch
        (``self.use_runtime_dispatch``), then emits either the runtime-lookup
        call or a direct fallback call, and finally emits the unbacked symbol
        definitions.
        """
        kernel = self.op_overload
        assert kernel is not None
        if kernel.namespace == "aten":
            # Aten Fallback Ops
            assert isinstance(kernel, torch._ops.OpOverload), type(kernel)
            if V.graph.cpp_wrapper:
                from torchgen.aoti.fallback_ops import inductor_fallback_ops

                if str(kernel) not in inductor_fallback_ops:
                    # C shim v2 is torchgen-ed, which should cover all aten ops.
                    # If you do hit a missed op, please update fallback_ops.py.
                    log.warning(
                        "%s is missing a c-shim implementation, using proxy executor as fallback",
                        kernel,
                    )
                    self.use_runtime_dispatch = True
        elif kernel.namespace == "_quantized":
            # Internal Quantized Fallback Ops
            assert isinstance(kernel, torch._ops.OpOverload), type(kernel)
        elif V.graph.cpp_wrapper:
            # For non-aten OpOverload, i.e. custom ops
            # If the op is in custom_ops_to_c_shims, generate direct function call
            self.use_runtime_dispatch = (
                kernel not in config.aot_inductor.custom_ops_to_c_shims
            )

        # Handle the special case where a complex number is input to a C-shim kernel for
        # a scalar input.  The torchgen'ed shim API will use type "double", which is
        # incompatible with complex numbers, forcing a fallback to runtime dispatch.
        if (
            V.graph.cpp_wrapper
            and isinstance(kernel, torch._ops.OpOverload)
            and not self.use_runtime_dispatch
        ):

            def is_number(t: torch.JitType) -> bool:
                # True for NumberType, looking through Optional[...] wrappers.
                if isinstance(t, torch.OptionalType):
                    return is_number(t.getElementType())
                return isinstance(t, torch.NumberType)

            # Using unflatten_args is a bit of a hack, but all the complex arguments we
            # care about are in self.constant_args, and calling unflatten_args puts them
            # in the correct order without triggering codegen.
            args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
            # Append kwarg values to args.  ordered_kwargs_for_cpp_kernel is guaranteed
            # to be set, since this is an OpOverload kernel.
            args_iter = itertools.chain(
                args,
                (
                    self.get_kwargs_value(k, **kwargs)
                    for k in self.ordered_kwargs_for_cpp_kernel
                ),
            )
            self.use_runtime_dispatch = any(
                isinstance(v, complex) and is_number(a.real_type)
                for v, a in zip(args_iter, kernel._schema.arguments)
            )

        self.codegen_comment(wrapper)
        if self.use_runtime_dispatch:
            exported_args = self.export_extern_kernel_node()
            assert self.python_kernel_name is not None
            assert self.op_overload is not None

            wrapper.generate_fallback_kernel_with_runtime_lookup(
                self.get_name(),
                self.python_kernel_name,
                lambda: [*self.codegen_args(), *self.codegen_kwargs()],
                self.op_overload,
                exported_args,
                # NOTE: [special handling of all_reduce_coalesced_'s return value]
                self.outputs if self.outputs else self.mutation_outputs,
            )
        else:
            wrapper.generate_fallback_kernel(self)
            if isinstance(self.layout, Layout):
                self.codegen_size_asserts(wrapper)
                self.codegen_alignment_asserts(wrapper)
                self.codegen_memory_tracking(wrapper)

        self.codegen_unbacked_symbol_defs(wrapper)

    @staticmethod
    def tensor_to_layout(output: torch.Tensor) -> FixedLayout:
        """Build a ``FixedLayout`` mirroring ``output``.

        Device, dtype, size and stride are copied from the tensor. The pinned
        flag falls back to ``False`` when ``is_pinned()`` raises ``RuntimeError``.
        """
        is_pinned = False
        try:
            is_pinned = output.is_pinned()
        except RuntimeError:
            # dispatch not implemented
            pass
        return FixedLayout(
            output.device,
            output.dtype,
            convert_shape_to_inductor(output.size()),
            convert_shape_to_inductor(output.stride()),
            is_pinned=is_pinned,
        )

    @classmethod
    def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKernel:
        """Create an instance of FallbackKernel from an _OpOverloads

        The kernel is run through ``process_kernel`` (under the graph's fake
        mode unless it is listed in ``fake_incorrect_kernels``) to obtain an
        example output. A packed node is then built, and the example output is
        mirrored into a structure of ``MultiOutput`` buffers, ints and symbolic
        expressions.

        NOTE(review): the value returned is that mirrored output structure,
        not the packed node, even though the annotation says
        ``FallbackKernel`` - confirm which one callers rely on.
        """
        fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
        if kernel not in fake_incorrect_kernels:
            context = cast(AbstractContextManager[None], V.graph.fake_mode)
        else:
            context = nullcontext()

        with context:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                unbacked_bindings,
            ) = cls.process_kernel(kernel, *args, **kwargs)

        # We need this extra check for input alignment since the example
        # inputs we created are always aligned.
        has_unaligned_input = any(is_unaligned(arg) for arg in tensor_args)

        device = cls.find_device(tensor_args, example_output)

        if not device and isinstance(
            kernel, torch._higher_order_ops.torchbind.CallTorchBind
        ):
            # use CPU device for torchbind methods that don't take in or output any tensor, e.g. size()
            device = torch.device("cpu")

        if example_output is None:
            packed = cls(
                NoneLayout(device=device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                unbacked_bindings=unbacked_bindings,
            )

        else:
            assert device, "Not sure where to find device info"
            packed = cls(
                MultiOutputLayout(device=device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                unbacked_bindings=unbacked_bindings,
            )

        def generate_output(output: Any, indices: list[tuple[Any, int]]) -> Any:
            # Recursively mirrors `output`; `indices` accumulates the
            # (container type, key) path from the root to the current element.
            if isinstance(output, (list, tuple)):
                return type(output)(
                    generate_output(output[i], indices + [(type(output), i)])
                    for i in range(len(output))
                )
            elif isinstance(output, dict):
                return {
                    key: generate_output(val, indices + [(type(output), key)])
                    for key, val in output.items()
                }
            elif isinstance(output, torch.Tensor):
                buf = MultiOutput(
                    cls.tensor_to_layout(output),
                    packed,
                    indices,
                )
                if (
                    config.assume_unaligned_fallback_output
                    or has_unaligned_input
                    or not tensor_is_aligned(output)
                ):
                    V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
                return buf
            elif isinstance(output, int):
                return output
            elif isinstance(output, torch.SymInt):
                return output.node.expr
            else:
                assert output is None, (
                    f"FallbackKernel output type {type(output)} is not supported"
                )
                return None

        outputs = generate_output(example_output, [])
        if isinstance(outputs, (list, tuple)):
            packed.outputs = outputs
        elif isinstance(outputs, dict):
            # NOTE(review): tuple() over a dict yields its keys, not its
            # values - confirm that is what consumers of `outputs` expect.
            packed.outputs = tuple(outputs)
        else:
            packed.outputs = [outputs]
        return outputs

    def apply_constraint(self) -> None:
        """Forward to the parent implementation unchanged."""
        return super().apply_constraint()
  6827. @ir_dataclass(frozen=False)
  6828. class ComplexView(FallbackKernel):
  6829. """View a complex number as two dtyped numbers or vice versa"""
  6830. def should_allocate(self) -> bool:
  6831. return False
  6832. def get_inputs_that_alias_output(self) -> Sequence[str]:
  6833. # Signal to codegen that our output buffer isn't safe to reuse
  6834. return [self.input_name(0)]
  6835. def __init__(
  6836. self,
  6837. layout: OutputSpec,
  6838. kernel: _OpOverloads,
  6839. tensor_args: Sequence[IRNode],
  6840. nontensor_args: Sequence[Any],
  6841. unflatten_args: Callable[..., Any],
  6842. *,
  6843. unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
  6844. ) -> None:
  6845. super().__init__(
  6846. layout,
  6847. kernel,
  6848. tensor_args,
  6849. nontensor_args,
  6850. unflatten_args,
  6851. unbacked_bindings=unbacked_bindings,
  6852. )
  6853. class MemoryCheckKernel(FallbackKernel):
  6854. """
  6855. Custom kernel for memory checking that generates direct function calls
  6856. TODO - the custom op was erroring with str inputs. should be able to custom op directly.
  6857. """
  6858. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6859. """Override codegen to write direct function call"""
  6860. # Extract our arguments from nontensor_args
  6861. wrapper.write_memory_track_allocation_once()
  6862. alive_list, dead_list, is_final_step = self.constant_args
  6863. alive_repr = repr(alive_list)
  6864. dead_repr = repr(dead_list)
  6865. if is_final_step:
  6866. wrapper.writeline(
  6867. "# note: dont currently distinguish between buffers returned and dealloc'd in last step"
  6868. )
  6869. call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})"
  6870. else:
  6871. call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})"
  6872. wrapper.writeline(call)
# Output spec that carries only a device.
# FallbackKernel.create uses it as the layout of the packed node when the
# example output is not None.
@ir_dataclass
class MultiOutputLayout(OutputSpec):
    # Device reported by get_device().
    device: torch.device

    def get_device(self) -> Optional[torch.device]:
        """Return the stored device."""
        return self.device
  6878. class MultiOutput(ExternKernel):
  6879. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  6880. wrapper.codegen_multi_output(self)
  6881. if not self.skip_size_stride_alignment_checks:
  6882. self.codegen_size_asserts(wrapper)
  6883. self.codegen_alignment_asserts(wrapper)
  6884. def __init__(
  6885. self,
  6886. layout: OutputSpec,
  6887. input: IRNode,
  6888. indices: list[tuple[Any, ...]],
  6889. skip_size_stride_alignment_checks: bool = False,
  6890. ) -> None:
  6891. super().__init__(None, layout, [input], ())
  6892. self.name = V.graph.register_buffer(self)
  6893. V.graph.register_operation(self)
  6894. self.indices = indices
  6895. self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
  6896. @cache_on_self_and_args("MultiOutput")
  6897. def get_free_symbol_uses(
  6898. self, unbacked_only: bool = False
  6899. ) -> OrderedSet[sympy.Symbol]:
  6900. input_node = self.inputs[0]
  6901. assert isinstance(input_node, IRNode), input_node
  6902. return input_node.get_free_symbol_uses(unbacked_only)
  6903. def should_allocate(self) -> bool:
  6904. return len(self.inputs) == 1 and (
  6905. isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM
  6906. )
  6907. def get_inputs_that_alias_output(self) -> Sequence[str]:
  6908. return [
  6909. inp.get_name()
  6910. for inp in self.inputs
  6911. if isinstance(inp, FallbackKernel)
  6912. and len(inp.get_inputs_that_alias_output()) > 0
  6913. ]
  6914. # We just use a normal dataclass for MutableBox/TensorBox/StorageBox since
  6915. # they're mainly lowering-time constructs that we expect to mutate and such.
  6916. @dataclasses.dataclass
  6917. class MutableBox(IRNode):
  6918. """
  6919. TensorBox / StorageBox allow in-place mutation of Tensors
  6920. """
  6921. data: IRNode
  6922. def has_exceeded_max_reads(self) -> bool:
  6923. return self.data.has_exceeded_max_reads()
  6924. def get_device(self) -> Optional[torch.device]:
  6925. return self.data.get_device()
  6926. def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
  6927. return self.data.make_loader()
  6928. def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
  6929. return self.data.make_indexer()
  6930. def get_stride(self) -> Sequence[_IntLike]:
  6931. return self.data.get_stride()
  6932. def get_name(self) -> str:
  6933. return self.data.get_name()
  6934. def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
  6935. return self.data.has_large_inner_fn(threshold)
  6936. def mark_reuse(self, users: int) -> None:
  6937. return self.data.mark_reuse(users)
  6938. def realize_hint(self) -> None:
  6939. return self.data.realize_hint()
  6940. def unwrap_view(self) -> IRNode:
  6941. return self.data.unwrap_view()
  6942. def is_input_buffer(self) -> bool:
  6943. return self.data.is_input_buffer()
  6944. def freeze_layout(self) -> None:
  6945. return self.data.freeze_layout()
  6946. def freeze_layout_with_stride_order(
  6947. self, order: Sequence[int], allow_padding: bool = False
  6948. ) -> None:
  6949. return self.data.freeze_layout_with_stride_order(order, allow_padding)
  6950. def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
  6951. return self.data.freeze_layout_with_fill_order(order)
  6952. def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None:
  6953. return self.data.freeze_layout_with_same_order(stride)
  6954. def freeze_layout_with_exact_strides(
  6955. self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
  6956. ) -> None:
  6957. return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding)
  6958. def get_read_writes(self) -> dependencies.ReadWrites:
  6959. return self.data.get_read_writes()
  6960. def get_reads(self) -> OrderedSet[Dep]:
  6961. return self.data.get_reads()
  6962. def num_reads(self) -> int:
  6963. return self.data.num_reads()
  6964. def get_storage_numel(self) -> _IntLike:
  6965. return self.data.get_storage_numel()
  6966. def get_reduction_type(self) -> Optional[str]:
  6967. return self.data.get_reduction_type()
  6968. def get_reduction_size(self) -> Sequence[Expr]:
  6969. return self.data.get_reduction_size()
  6970. def is_extern(self) -> bool:
  6971. return self.data.is_extern()
  6972. def is_no_op(self) -> bool:
  6973. return self.data.is_no_op()
  6974. def constant_to_device(self, device: torch.device) -> IRNode:
  6975. return self.data.constant_to_device(device)
  6976. def get_mutation_names(self) -> Sequence[str]:
  6977. return self.data.get_mutation_names()
  6978. def get_operation_name(self) -> str:
  6979. return self.data.get_operation_name()
  6980. def get_inputs_that_alias_output(self) -> Sequence[str]:
  6981. return self.data.get_inputs_that_alias_output()
  6982. def realize(self) -> Optional[str]:
  6983. return self.data.realize()
  6984. @cache_on_self_and_args("MutableBox")
  6985. def get_free_symbol_uses(
  6986. self, unbacked_only: bool = False
  6987. ) -> OrderedSet[sympy.Symbol]:
  6988. return self.data.get_free_symbol_uses(unbacked_only)
  6989. def get_read_names(self) -> OrderedSet[str]:
  6990. return self.data.get_read_names()
  6991. def get_defining_op(self) -> Optional[Operation]:
  6992. return self.data.get_defining_op()
  6993. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  6994. return self.data.codegen_reference(writer)
  6995. @property
  6996. def layout(self) -> OutputSpec:
  6997. # we intentionally call get_output_spec (rather than get_layout) since Buffer.layout is an OutputSpec
  6998. return self.data.get_output_spec()
  6999. def get_layout(self) -> Layout:
  7000. return self.data.get_layout()
  7001. def get_output_spec(self) -> OutputSpec:
  7002. return self.data.get_output_spec()
  7003. def get_size(self) -> Sequence[Expr]:
  7004. return self.data.get_size()
  7005. @property
  7006. def dtype(self) -> torch.dtype:
  7007. return self.data.dtype
  7008. def __str__(self) -> str:
  7009. if isinstance(self.data, MutableBox):
  7010. line0 = f"{type(self).__name__}({type(self.data).__name__}("
  7011. endl = "))"
  7012. inner = self.data.data
  7013. else:
  7014. line0 = f"{type(self).__name__}("
  7015. inner = self.data
  7016. endl = ")"
  7017. lines = [
  7018. line0,
  7019. indent(str(inner)),
  7020. endl,
  7021. ]
  7022. return "\n".join(lines)
  7023. __repr__ = __str__
  7024. class TensorBox(MutableBox):
  7025. @staticmethod
  7026. def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]:
  7027. if isinstance(data, ShapeAsConstantBuffer):
  7028. return data
  7029. return TensorBox(StorageBox(data))
  7030. class StorageBox(MutableBox):
  7031. """
  7032. StorageBox allow in-place mutation of Tensors
  7033. """
  7034. def is_input_buffer(self) -> bool:
  7035. if isinstance(self.data, (InputBuffer, ReinterpretView)):
  7036. return self.data.get_name() in V.graph.graph_inputs
  7037. return False
  7038. def is_module_buffer(self) -> bool:
  7039. return (
  7040. isinstance(self.data, (ConstantBuffer))
  7041. and self.data.get_name() in V.graph.constants
  7042. )
  7043. def realize(self) -> Optional[str]:
  7044. if IRNode.is_realized_node(self.data):
  7045. return self.data.get_name()
  7046. assert isinstance(self.data, (Pointwise, Reduction, Scan, Sort)), type(
  7047. self.data
  7048. )
  7049. origin_node = self.data.get_origin_node()
  7050. traceback = self.data.get_traceback()
  7051. device = self.data.get_device()
  7052. assert device is not None
  7053. self.data = ComputedBuffer(
  7054. name=None,
  7055. layout=FlexibleLayout(
  7056. device=device,
  7057. dtype=self.data.get_dtype(),
  7058. size=self.data.get_size(),
  7059. is_pinned=False,
  7060. ),
  7061. data=self.data,
  7062. )
  7063. self.data.name = V.graph.register_buffer(self.data)
  7064. V.graph.register_operation(self.data)
  7065. self.data.origins = self.origins
  7066. self.data.origin_node = origin_node
  7067. self.data.traceback = traceback
  7068. return self.data.name
  7069. def realize_hint(self) -> None:
  7070. """
  7071. Called on buffers we expect to be forced to realize later.
  7072. """
  7073. if (
  7074. isinstance(self.data, (Pointwise, Reduction))
  7075. and self.data.inner_fn_opcount().nontrivial_read_count > 1
  7076. ):
  7077. self.realize()
  7078. def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool:
  7079. return (
  7080. sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) > threshold
  7081. )
  7082. def has_exceeded_max_reads(self) -> bool:
  7083. return isinstance(self.data, Pointwise) and (
  7084. self.num_reads() > config.realize_acc_reads_threshold
  7085. or self.has_large_inner_fn()
  7086. or (
  7087. config.realize_acc_reads_size_threshold is not None
  7088. and self.has_accumulated_enough_reads_by_size(
  7089. config.realize_acc_reads_size_threshold
  7090. )
  7091. )
  7092. )
  7093. def should_realize_on_reuse(self, users: int) -> bool:
  7094. """
  7095. A heuristic to decide if we should realize a tensor
  7096. that is used multiple times.
  7097. """
  7098. if users > 1 and isinstance(self.data, (Pointwise, Reduction)):
  7099. if is_cpu(self.data):
  7100. # Heuristic for realizing reused result of heavy ops on cpu
  7101. opcount = self.data.inner_fn_opcount()
  7102. heavy_ops = ["exp", "sigmoid"] # a list of heavy ops
  7103. if any(x in opcount.used_ops for x in heavy_ops):
  7104. return True
  7105. return (
  7106. self.num_reads() > config.realize_reads_threshold
  7107. or self.has_large_inner_fn()
  7108. )
  7109. return False
  7110. def mark_reuse(self, users: int) -> None:
  7111. if self.should_realize_on_reuse(users):
  7112. self.realize()
  7113. def num_reads(self) -> int:
  7114. return self.data.num_reads()
@ir_dataclass(frozen=False)
class Subgraph(IRNode):
    """A subgraph referenced by the higher-order-op nodes in this file
    (InvokeSubgraph, Conditional, WhileLoop)."""

    # Passed as ``subgraph_name`` when the subgraph is lowered via
    # ``V.graph.make_subgraph``.
    name: str
    # The FX module to lower; passed as ``gm`` to ``V.graph.make_subgraph``.
    graph_module: torch.fx.GraphModule
    # Lowered graph; starts as None and is filled in by the ``create``
    # methods of the nodes above the first time they see this subgraph.
    graph: Optional[GraphLowering] = None
  7120. def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
  7121. buffers = [
  7122. buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer
  7123. for buffer in buffers
  7124. ]
  7125. # assuming the same buffer is represented by the same IRNode object
  7126. return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers)
  7127. @ir_dataclass(frozen=False)
  7128. class InvokeSubgraph(ExternKernel):
  7129. """
  7130. Ir node for the invoke_subgraph HOP.
  7131. """
  7132. subgraph: Optional[Subgraph] = None
  7133. operands: Optional[Sequence[IRNode]] = None
  7134. outputs: Optional[Sequence[IRNode]] = None
  7135. def __init__(
  7136. self, subgraph: Subgraph, operands: Sequence[IRNode], layout: MultiOutputLayout
  7137. ) -> None:
  7138. super().__init__(
  7139. name=None,
  7140. layout=layout,
  7141. inputs=operands,
  7142. )
  7143. self.subgraph = subgraph
  7144. self.name = V.graph.register_buffer(self)
  7145. V.graph.register_operation(self)
  7146. @classmethod
  7147. def create(
  7148. cls, subgraph: Subgraph, *operands: IRNode
  7149. ) -> list[Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]]:
  7150. """For each operand, get a realized input, force it to have the same
  7151. strides as the subgraph inputs, then use an InvokeSubgraph"""
  7152. from .lowering import constrain_to_fake_tensor
  7153. # TODO(anijain2305) - Support sym expr as operands in future.
  7154. current_node = V.graph.current_node
  7155. fake_operands = None
  7156. if eager_input_vals := current_node.meta.get("eager_input_vals"):
  7157. # eager_input_vals is (args_values, kwargs_values). We need args for invoke_subgraph
  7158. fake_operands = eager_input_vals[0][2:]
  7159. else:
  7160. # For the partitioned backward graph, we do not have
  7161. # eager_input_vals. Here, we rely on the recorded example values.
  7162. fx_operands = current_node.args[2:]
  7163. fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr]
  7164. # Realize the inputs. Also intermediates can have different strides than
  7165. # the inputs of the subgraph. So, force the intermediates to have same
  7166. # strides as that of subgraph inputs.
  7167. operands: list[IRNode] = [cls.realize_input(x) for x in operands]
  7168. new_operands: list[IRNode] = []
  7169. for idx, operand in enumerate(operands):
  7170. if isinstance(operand, (ShapeAsConstantBuffer, GeneratorState)):
  7171. new_operands.append(operand)
  7172. else:
  7173. new_operands.append(
  7174. constrain_to_fake_tensor(operand, fake_operands[idx])
  7175. )
  7176. operands = new_operands
  7177. if subgraph.graph is None:
  7178. # create and lower subgraphs
  7179. subgraph.graph = V.graph.make_subgraph(
  7180. gm=subgraph.graph_module,
  7181. example_inputs=fake_operands,
  7182. subgraph_name=subgraph.name,
  7183. )
  7184. with V.set_graph_handler(subgraph.graph):
  7185. subgraph.graph.run(*fake_operands)
  7186. outputs = subgraph.graph.graph_outputs
  7187. # Find the device - operands could be integers from shapes, so we can't
  7188. # use operands[0]
  7189. device = None
  7190. for operand in operands:
  7191. if not isinstance(operand, ShapeAsConstantBuffer):
  7192. device = operand.get_device()
  7193. break
  7194. assert device is not None
  7195. invoke_subgraph = InvokeSubgraph(
  7196. subgraph=subgraph,
  7197. operands=operands,
  7198. layout=MultiOutputLayout(device=device),
  7199. )
  7200. def create_output(
  7201. output: IRNode, ind: int
  7202. ) -> Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]:
  7203. if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
  7204. return output
  7205. else:
  7206. device = output.get_device()
  7207. assert device is not None
  7208. return MultiOutput(
  7209. FixedLayout(
  7210. device=device,
  7211. dtype=output.get_dtype(),
  7212. size=output.get_size(),
  7213. stride=output.get_stride(),
  7214. offset=output.get_layout().offset,
  7215. is_pinned=output.get_layout().is_pinned,
  7216. ),
  7217. invoke_subgraph, # type: ignore[has-type]
  7218. [(list, ind)],
  7219. skip_size_stride_alignment_checks=True,
  7220. )
  7221. outs = [create_output(output, i) for i, output in enumerate(outputs)]
  7222. invoke_subgraph.outputs = outs # type: ignore[assignment]
  7223. return outs
  7224. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7225. wrapper.codegen_invoke_subgraph(self)
  7226. @ir_dataclass(frozen=False)
  7227. class Conditional(ExternKernel):
  7228. predicate: Optional[IRNode] = None
  7229. operands: Optional[Sequence[IRNode]] = None
  7230. true_subgraph: Optional[Subgraph] = None
  7231. false_subgraph: Optional[Subgraph] = None
  7232. outputs: Optional[Sequence[MultiOutput]] = None
  7233. def __init__(
  7234. self,
  7235. predicate: IRNode,
  7236. operands: Sequence[IRNode],
  7237. true_subgraph: Subgraph,
  7238. false_subgraph: Subgraph,
  7239. layout: MultiOutputLayout,
  7240. unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
  7241. ) -> None:
  7242. self.predicate = predicate
  7243. self.operands = operands
  7244. self.true_subgraph = true_subgraph
  7245. self.false_subgraph = false_subgraph
  7246. sym_args, tensor_args = _split_by_sym_type([predicate, *operands])
  7247. super().__init__(
  7248. name=None,
  7249. layout=layout,
  7250. inputs=tensor_args,
  7251. constant_args=sym_args,
  7252. )
  7253. if unbacked_bindings is not None:
  7254. self.unbacked_bindings = unbacked_bindings
  7255. self.name = V.graph.register_buffer(self)
  7256. V.graph.register_operation(self)
  7257. @staticmethod
  7258. def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
  7259. if isinstance(s, int):
  7260. return s
  7261. return s.node.expr
  7262. @classmethod
  7263. def create(
  7264. cls,
  7265. predicate: TensorBox,
  7266. true_fn: Subgraph,
  7267. false_fn: Subgraph,
  7268. operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
  7269. ) -> Sequence[IRNode]:
  7270. """Create a Sequence of IRNodes from a conditional statement (see .lowering.cond)"""
  7271. predicate = cls.realize_input(predicate)
  7272. operands = [cls.realize_input(x) for x in operands]
  7273. fx_operands: Argument = V.graph.current_node.args[-1]
  7274. assert isinstance(fx_operands, Sequence), type(fx_operands)
  7275. assert all(isinstance(n, Node) for n in fx_operands)
  7276. fake_operands = [cast(Node, x).meta["val"] for x in fx_operands]
  7277. for subgraph in (true_fn, false_fn):
  7278. if subgraph.graph is None:
  7279. # create and lower subgraphs
  7280. subgraph.graph = V.graph.make_subgraph(
  7281. gm=subgraph.graph_module,
  7282. example_inputs=fake_operands,
  7283. subgraph_name=subgraph.name,
  7284. )
  7285. with V.set_graph_handler(subgraph.graph):
  7286. subgraph.graph.run(*fake_operands)
  7287. assert true_fn.graph is not None
  7288. assert false_fn.graph is not None
  7289. true_outputs = true_fn.graph.graph_outputs
  7290. false_outputs = false_fn.graph.graph_outputs
  7291. for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)):
  7292. if _has_aliased_buffers(true_outputs):
  7293. raise AssertionError(
  7294. "Output aliasing is currently not supported in compiled torch.cond. "
  7295. f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}"
  7296. )
  7297. # make sure true and false outputs are structurally equivalent
  7298. assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
  7299. for i, (t_o, f_o) in enumerate(zip(true_outputs, false_outputs)):
  7300. assert t_o.get_device() == f_o.get_device(), (i, t_o, f_o)
  7301. assert t_o.get_dtype() == f_o.get_dtype(), (i, t_o, f_o)
  7302. assert t_o.get_layout().offset == f_o.get_layout().offset, (i, t_o, f_o)
  7303. device = next(
  7304. o.get_device()
  7305. for o in [predicate] + operands
  7306. if not isinstance(o, ShapeAsConstantBuffer)
  7307. )
  7308. unbacked_bindings = resolve_unbacked_bindings(
  7309. V.graph.sizevars.shape_env,
  7310. V.graph.current_node.meta.get("unbacked_bindings", None),
  7311. )
  7312. assert device is not None, "cannot determine device"
  7313. conditional = Conditional(
  7314. predicate=predicate,
  7315. operands=operands,
  7316. true_subgraph=true_fn,
  7317. false_subgraph=false_fn,
  7318. layout=MultiOutputLayout(device=device),
  7319. unbacked_bindings=unbacked_bindings,
  7320. )
  7321. outputs = [
  7322. MultiOutput(
  7323. FixedLayout(
  7324. device=device,
  7325. dtype=output.get_dtype(),
  7326. size=[Conditional._maybe_expr(sz) for sz in merged_output.size()],
  7327. stride=[
  7328. Conditional._maybe_expr(sz) for sz in merged_output.stride()
  7329. ],
  7330. offset=output.get_layout().offset,
  7331. is_pinned=output.get_layout().is_pinned,
  7332. ),
  7333. conditional,
  7334. [(list, i)],
  7335. )
  7336. # as the true and false outputs are equivalent,
  7337. # we can use either of them here as a "template"
  7338. for i, (output, merged_output) in enumerate(
  7339. zip(true_outputs, V.graph.current_node.meta["val"])
  7340. )
  7341. ]
  7342. conditional.outputs = outputs # type: ignore[assignment]
  7343. return outputs
  7344. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7345. wrapper.codegen_conditional(self)
  7346. wrapper.codegen_unbacked_symbol_defs_for_outputs(
  7347. self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
  7348. )
  7349. def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
  7350. if unbacked_bindings := getattr(self, "unbacked_bindings", None):
  7351. resolved = resolve_unbacked_bindings(
  7352. V.graph.sizevars.shape_env, unbacked_bindings
  7353. )
  7354. assert resolved is not None
  7355. return OrderedSet(resolved.keys())
  7356. else:
  7357. return OrderedSet()
  7358. def _split_by_sym_type(
  7359. args: list[Any],
  7360. ) -> tuple[list[ShapeAsConstantBuffer], list[Any]]:
  7361. non_sym_args = []
  7362. sym_args = []
  7363. for arg in args:
  7364. if isinstance(arg, ShapeAsConstantBuffer):
  7365. sym_args.append(arg.expr)
  7366. else:
  7367. non_sym_args.append(arg)
  7368. return sym_args, non_sym_args
@ir_dataclass(frozen=False)
class WhileLoop(ExternKernel):
    """The IR node for while_loop and while_loop_stack_output. It supports input mutation."""

    # Loop-carried values; body_fn must return one output per carried input.
    carried_inputs: Optional[Sequence[IRNode]] = None
    # Extra inputs passed to both subgraphs alongside the carried inputs.
    additional_inputs: Optional[Sequence[IRNode]] = None
    cond_subgraph: Optional[Subgraph] = None
    body_subgraph: Optional[Subgraph] = None
    outputs: Optional[Sequence[MultiOutput]] = None

    def __init__(
        self,
        carried_inputs: Sequence[IRNode],
        additional_inputs: Sequence[IRNode],
        cond_subgraph: Subgraph,
        body_subgraph: Subgraph,
        layout: MultiOutputLayout,
        unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
        stack_output: bool,
    ) -> None:
        """Store the loop pieces and register this node with the current graph.

        All inputs are split with ``_split_by_sym_type``: ShapeAsConstantBuffer
        entries become ``constant_args``, everything else becomes ``inputs``.
        """
        self.carried_inputs = carried_inputs
        self.additional_inputs = additional_inputs
        self.cond_subgraph = cond_subgraph
        self.body_subgraph = body_subgraph

        sym_args, tensor_args = _split_by_sym_type(
            [*carried_inputs, *additional_inputs]
        )
        super().__init__(
            name=None,
            layout=layout,
            inputs=tensor_args,
            constant_args=sym_args,
        )
        # Only set when present; readers use getattr() with a default.
        if unbacked_bindings is not None:
            self.unbacked_bindings = unbacked_bindings
        self.stack_output = stack_output

        self.name = V.graph.register_buffer(self)
        V.graph.register_operation(self)

    # Accidental aliasing can be created due to cse, where the empty buffers we
    # allocated for backward to use get csed into the same buffer in function fx_graph_cse.
    # See test_scan_multiple_layers_gradient for a concrete example.
    @staticmethod
    def _clone_aliased_inputs(carried_inputs: Sequence[IRNode]) -> Sequence[IRNode]:
        """Return ``carried_inputs`` with repeated buffers replaced by clones.

        The first occurrence of each underlying buffer is kept; every later
        occurrence is cloned. The input is returned unchanged when nothing is
        aliased.
        """
        if not _has_aliased_buffers(carried_inputs):
            return carried_inputs

        # Import clone from lowering module
        from .lowering import clone

        # Unwrap views to get the underlying buffers for comparison
        unwrapped_buffers = [
            buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer
            for buffer in carried_inputs
        ]

        # Track which underlying buffers (by id) have already been seen
        seen_buffers: OrderedSet[int] = OrderedSet()
        result = []

        for i, (original_input, unwrapped_buffer) in enumerate(
            zip(carried_inputs, unwrapped_buffers)
        ):
            if id(unwrapped_buffer) in seen_buffers:
                result.append(clone(original_input))
            else:
                seen_buffers.add(id(unwrapped_buffer))
                result.append(original_input)

        return result

    @classmethod
    def create(
        cls,
        cond_fn: Subgraph,
        body_fn: Subgraph,
        carried_inputs: Sequence[IRNode],
        additional_inputs: Sequence[IRNode],
        stack_output: bool,
    ) -> Union[IRNode, Sequence[IRNode]]:
        """Create the while_loop IR node. ``stack_output`` controls whether each
        iteration's output is stacked, which is necessary for training.

        Returns the list of output nodes: ``MultiOutput`` nodes, with the
        mutated input itself in place of any carried input that body_fn mutates.
        """
        from torch._higher_order_ops.utils import check_input_alias_and_mutation

        def _require_exact_strides(
            tensor_boxes: Sequence[IRNode],
            fake_tensors: list[Union[int, torch.SymInt, torch.Tensor]],
        ) -> list[IRNode]:
            """Force each node paired with a fake Tensor to that tensor's exact
            strides (no padding); nodes paired with ints/SymInts pass through."""
            assert len(tensor_boxes) == len(fake_tensors)
            ret = []
            for tb, fk in zip(tensor_boxes, fake_tensors):
                if isinstance(fk, torch.Tensor):
                    ret.append(
                        ExternKernel.require_exact_strides(
                            tb, fk.stride(), allow_padding=False
                        )
                    )
                else:
                    ret.append(tb)
            return ret

        # The fx node's last two args are the carried and additional inputs.
        fx_carried_inputs = V.graph.current_node.args[-2]
        fx_additional_inputs = V.graph.current_node.args[-1]
        fx_all_inputs = fx_carried_inputs + fx_additional_inputs  # type: ignore[operator]
        fake_all_inputs = [x.meta["val"] for x in fx_all_inputs]  # type: ignore[union-attr]
        fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs]  # type: ignore[union-attr]
        fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs]  # type: ignore[union-attr]

        # Carried inputs: realize, de-alias, then pin strides to the fake values.
        carried_inputs_ = [cls.realize_input(x) for x in carried_inputs]
        carried_inputs_ = WhileLoop._clone_aliased_inputs(carried_inputs_)
        carried_inputs_ = _require_exact_strides(carried_inputs_, fake_carried_inputs)
        additional_inputs_ = [cls.realize_input(x) for x in additional_inputs]
        additional_inputs_ = _require_exact_strides(
            additional_inputs_, fake_additional_inputs
        )
        all_inputs = carried_inputs_ + additional_inputs_

        for subgraph in (cond_fn, body_fn):
            if subgraph.graph is None:
                # create and lower subgraphs
                assert isinstance(fx_all_inputs, Sequence), type(fx_all_inputs)
                subgraph.graph = V.graph.make_subgraph(
                    gm=subgraph.graph_module,
                    example_inputs=fx_all_inputs,  # type: ignore[arg-type]
                    subgraph_name=subgraph.name,
                )
                with V.set_graph_handler(subgraph.graph):
                    subgraph.graph.run(*fake_all_inputs)
                    # For body_fn, we require its output to have the exact same stride
                    # as inputs because the previous output is the input of next iteration.
                    #
                    # This cannot be automatically done in graph lowering because body_fn's graph outputs
                    # are not user-facing so the special handling for strides of user-facing output in graph
                    # lowering is not applicable.
                    if subgraph is body_fn:
                        assert len(subgraph.graph.graph_outputs) == len(
                            fake_carried_inputs
                        )
                        subgraph.graph.graph_outputs = _require_exact_strides(  # type: ignore[assignment]
                            subgraph.graph.graph_outputs,
                            fake_carried_inputs,
                        )

        assert cond_fn.graph and body_fn.graph
        cond_outputs = cond_fn.graph.graph_outputs
        body_outputs = body_fn.graph.graph_outputs

        if _has_aliased_buffers(body_outputs):
            raise AssertionError(
                "Output aliasing is currently not supported in compiled torch.while_loop. "
                f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}"
            )

        # make sure cond_fn returns a boolean scalar Tensor
        # (the dtype/rank checks are skipped for a ShapeAsConstantBuffer result)
        assert len(cond_outputs) == 1, cond_outputs
        p = cond_outputs[0]
        if not isinstance(p, ShapeAsConstantBuffer):
            assert p.get_dtype() == torch.bool, p
            assert len(p.get_size()) == 0, p

        assert len(all_inputs) > 0, (
            "torch.while_loop is assumed to have at least one operand."
        )

        # NOTE(review): the device is taken from the first input without
        # skipping ShapeAsConstantBuffer entries — confirm that is intended.
        device = all_inputs[0].get_device()

        assert device is not None  # to make linter happy
        # make sure carried_inputs_ and body outputs are structurally equivalent
        assert len(carried_inputs_) == len(body_outputs), (
            carried_inputs_,
            body_outputs,
        )
        for i, (op, bo) in enumerate(zip(carried_inputs_, body_outputs)):

            def _guard_list_equals(
                lhs_exprs: Sequence[Union[int, sympy.Expr]],
                rhs_exprs: Sequence[Union[int, sympy.Expr]],
            ) -> None:
                """Check element-wise equality of two expression lists via sizevars."""
                assert len(lhs_exprs) == len(rhs_exprs)
                for lhs, rhs in zip(lhs_exprs, rhs_exprs):
                    V.graph.sizevars.check_equals(lhs, rhs)

            _guard_list_equals(op.get_size(), bo.get_size())
            _guard_list_equals(op.get_stride(), bo.get_stride())
            # assume all carried_inputs_ and outputs are on the same device
            # as the MultiOutputLayout below requires single device
            assert op.get_device() == bo.get_device(), (i, op, bo, device)
            assert op.get_dtype() == bo.get_dtype(), (i, op, bo)

        assert device is not None

        unbacked_bindings = resolve_unbacked_bindings(
            V.graph.sizevars.shape_env,
            V.graph.current_node.meta.get("unbacked_bindings", None),
        )

        while_loop = WhileLoop(
            carried_inputs=carried_inputs_,
            additional_inputs=additional_inputs_,
            cond_subgraph=cond_fn,
            body_subgraph=body_fn,
            # asserted above that there is at least one operand
            layout=MultiOutputLayout(device=device),
            unbacked_bindings=unbacked_bindings,
            stack_output=stack_output,
        )

        assert body_fn.graph is not None and isinstance(
            body_fn.graph.module, torch.fx.GraphModule
        )  # to make linter happy

        # Handling input mutations
        # (element [3] of the helper's result is used as the mutated input indices)
        mutated_idxs = check_input_alias_and_mutation(
            body_fn.graph.module, fake_all_inputs
        )[3]
        mutated_idx_set = OrderedSet(mutated_idxs)
        mutated_inputs = [all_inputs[idx] for idx in mutated_idx_set]

        # Create all outputs first
        mutated_inputs_iter = iter(mutated_inputs)
        all_outputs: list[IRNode] = []
        while_loop.outputs = []
        while_loop.mutation_outputs = []
        if stack_output:
            assert len(mutated_idx_set) == 0, (
                "NYI: while_loop_stack_output input mutations."
            )
            # Stacked outputs take their layout from the fx node's recorded values.
            for idx, output in enumerate(V.graph.current_node.meta["val"]):
                # Create MultiOutput for regular outputs
                multi_out = MultiOutput(
                    FixedLayout(
                        device=output.device,  # type: ignore[arg-type]
                        dtype=output.dtype,
                        size=[Conditional._maybe_expr(sz) for sz in output.size()],
                        stride=[Conditional._maybe_expr(st) for st in output.stride()],
                    ),
                    while_loop,
                    [(list, idx)],
                )
                while_loop.outputs.append(multi_out)
                all_outputs.append(multi_out)
        else:
            for idx, output in enumerate(body_outputs):
                if idx in mutated_idx_set:
                    assert idx < len(carried_inputs), "only carries can be mutated."
                    # Create MutationOutput for mutated inputs
                    mutated_input = next(mutated_inputs_iter)
                    while_loop.mutation_outputs.append(
                        MutationOutput(mutated_input.layout, mutated_input, while_loop)  # type: ignore[attr-defined, union-attr]
                    )
                    # The mutated input itself stands in as the output.
                    all_outputs.append(mutated_input)
                else:
                    multi_out = MultiOutput(
                        FixedLayout(
                            device=output.get_device(),  # type: ignore[arg-type]
                            dtype=output.get_dtype(),
                            size=output.get_size(),
                            stride=output.get_stride(),
                            offset=output.get_layout().offset,
                        ),
                        while_loop,
                        [(list, idx)],
                    )
                    while_loop.outputs.append(multi_out)
                    all_outputs.append(multi_out)

        for inp, out in zip(carried_inputs, all_outputs):
            if inp.get_name() in V.graph.graph_inputs:
                # if a carried input of the while_loop is a graph input,
                # it can be returned as is when the number of iterations
                # is zero. due to this, we can't (generally) reuse the
                # output buffers corresponding to the graph inputs, as
                # the inputs may end up being mutated.
                V.graph.never_reuse_buffers.add(out.get_name())
        return all_outputs

    def codegen(self, wrapper: PythonWrapperCodegen) -> None:
        """Emit the while loop, then the unbacked-symbol definitions for its outputs."""
        wrapper.codegen_while_loop(self, self.stack_output)
        wrapper.codegen_unbacked_symbol_defs_for_outputs(
            self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
        )

    def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
        """Return the symbols defined by this node's unbacked bindings, or an
        empty set when no bindings were stored on the node."""
        if unbacked_bindings := getattr(self, "unbacked_bindings", None):
            resolved = resolve_unbacked_bindings(
                V.graph.sizevars.shape_env, unbacked_bindings
            )
            assert resolved is not None
            return OrderedSet(resolved.keys())
        else:
            return OrderedSet()
  7631. class EffectfulKernel(FallbackKernel):
  7632. def __init__(
  7633. self,
  7634. layout: OutputSpec,
  7635. kernel: _OpOverloads,
  7636. tensor_args: Sequence[IRNode],
  7637. nontensor_args: Sequence[Any],
  7638. unflatten_args: Callable[..., Any],
  7639. kwargs: Optional[dict[str, Any]] = None,
  7640. *,
  7641. unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
  7642. ) -> None:
  7643. super().__init__(
  7644. layout,
  7645. kernel,
  7646. tensor_args,
  7647. nontensor_args,
  7648. unflatten_args,
  7649. kwargs=None,
  7650. unbacked_bindings=unbacked_bindings,
  7651. )
  7652. from torch._higher_order_ops.effects import get_effect_key
  7653. uncovered_args = [
  7654. a.value if isinstance(a, TorchBindObject) else a for a in tensor_args
  7655. ]
  7656. effect_type = get_effect_key(kernel, (*nontensor_args, *uncovered_args), kwargs)
  7657. assert effect_type is not None
  7658. self.effect_type = effect_type
  7659. self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
  7660. V.graph.effectful_ops[effect_type] = self
  7661. def get_read_writes(self) -> dependencies.ReadWrites:
  7662. read_writes = super().get_read_writes()
  7663. if self.prev_effect_buffer is not None:
  7664. read_writes.reads.add(
  7665. dependencies.StarDep(self.prev_effect_buffer.get_name())
  7666. )
  7667. return read_writes
  7668. def has_side_effects(self) -> bool:
  7669. return True
class NonTensorObj(IRNode):
    """Base IR node for values that are not tensors (subclassed in this file
    by TorchBindObject and GeneratorState)."""

    @cache_on_self_and_args("NonTensorObj")
    def get_free_symbol_uses(
        self, unbacked_only: bool = False
    ) -> OrderedSet[sympy.Symbol]:
        """Return an empty set regardless of ``unbacked_only``: this node
        reports no free symbols."""
        return OrderedSet()
  7676. @ir_dataclass
  7677. class TorchBindObject(NonTensorObj):
  7678. name: str
  7679. value: Union[FakeScriptObject, torch.ScriptObject]
  7680. def get_name(self) -> str:
  7681. return self.name
  7682. def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
  7683. return self.name
  7684. def get_value(self) -> Union[FakeScriptObject, torch.ScriptObject]:
  7685. return self.value
  7686. def get_real_obj(self) -> torch.ScriptObject:
  7687. if isinstance(self.value, torch.ScriptObject):
  7688. return self.value
  7689. else:
  7690. return self.value.real_obj
  7691. def get_buf_bytes(self) -> int:
  7692. # Returns the sum of all tensors in the flattened object
  7693. real_script_obj = self.get_real_obj()
  7694. assert hasattr(real_script_obj, "__obj_flatten__")
  7695. flat_dict = dict(real_script_obj.__obj_flatten__())
  7696. flat_elems = pytree.tree_flatten(flat_dict)[0]
  7697. flat_sizes = [
  7698. x.element_size() * x.numel()
  7699. for x in flat_elems
  7700. if isinstance(x, torch.Tensor)
  7701. ]
  7702. return functools.reduce(operator.add, flat_sizes, 0)
@ir_dataclass
class GeneratorState(NonTensorObj):
    """Named non-tensor IR node carrying a device; InvokeSubgraph.create
    passes operands of this type through without stride constraints."""

    name: str
    device: torch.device

    def get_name(self) -> str:
        """Return this node's name."""
        return self.name

    def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
        """Return the name as the generated-code reference; ``writer`` is unused."""
        return self.name
class _CollectiveKernel(FallbackKernel):
    """Fallback kernel for collective ops, with in-place and out-of-place
    constructors (see the safety notes on each below)."""

    def should_allocate(self) -> bool:
        """Always False for collective kernels."""
        return False

    def has_side_effects(self) -> bool:
        """Always True for collective kernels."""
        return True

    # This is identical to FallbackKernel.set_cpp_kernel(), minus the
    # part that checks against input aliasing and mutation.
    def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
        """Set ``cpp_kernel_name`` (defaulting to the op schema's name) and
        record the op's keyword-only argument names in schema order."""
        assert type(self.op_overload) is torch._ops.OpOverload, (
            "Setting cpp kernel needs a valid op_overload"
        )
        kernel = self.op_overload
        if cpp_kernel_name is not None:
            self.cpp_kernel_name = cpp_kernel_name
        else:
            self.cpp_kernel_name = kernel._schema.name

        self.ordered_kwargs_for_cpp_kernel = [
            x.name for x in kernel._schema.arguments if x.kwarg_only
        ]

    # NOTE: [In-Place Collective Safety]
    # Between the initiation and completion of an in-place collective, the
    # input buffers are subject to both volatile reads and volatile writes.
    # They must not be read, written to or reused by another kernel. To ensure
    # the constraints, we model collective -> wait_tensor as a two-step
    # mutation of the input buffers.
    @classmethod
    def create_inplace(
        cls,
        kernel: _OpOverloads,
        inputs: Union[IRNode, list[IRNode]],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Create an in-place collective kernel for ``kernel`` on ``inputs``.

        Every leaf of ``inputs`` (and ``kwargs["out"]`` when given) is recorded
        as a mutation output and as an alias name of the kernel. Returns None.
        """
        with V.graph.fake_mode:
            (
                _example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                unbacked_bindings,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
        for tensor_arg in tensor_args:
            tensor_arg.realize()

        # The device is taken from the first tensor argument.
        device = tensor_args[0].get_device()
        packed = cls(
            NoneLayout(device=device),
            kernel,
            tensor_args,
            non_tensor_args,
            unflatten_args,
        )

        inps = pytree.tree_leaves(inputs)
        packed.mutation_outputs.extend(
            [MutationOutput(NoneLayout(device=device), buf, packed) for buf in inps]
        )

        # For inplace collective ops, the input is guaranteed to be alias of the returned value of op.
        packed.alias_names.extend([inp.get_name() for inp in inps])
        if "out" in kwargs:
            packed.mutation_outputs.append(
                MutationOutput(NoneLayout(device=device), kwargs["out"], packed)
            )
            # For out-variant collective ops, the `out=` arg is guaranteed to be alias of the returned value of op.
            packed.alias_names.append(kwargs["out"].get_name())

    # NOTE: [Out-of-Place Collective Safety]
    # Between the initiation and completion of an out-of-place collective:
    #
    # Input buffers:
    # - Are subject to volatile reads
    # - Can be read by another kernel
    # - Must not be written to or reused by another kernel
    #
    # Output buffers:
    # - Are subject to volatile writes
    # - Must not be read, written to or reused by another kernel
    #
    # To ensure the safety of input buffers without sacrificing read
    # availability, we add input buffers as read deps of wait_tensor kernels.
    #
    # To ensure the safety of output buffers, we model wait_tensor as a
    # mutation to the output buffer. Note we also assume the user program is
    # correct and the output buffer is not consumed by kernels other than
    # wait_tensor.
    #
    # TODO(yifu): add a pre-grad pass to validate the correctness of collective
    # usage in the user program.
    @classmethod
    def create_out_of_place(
        cls,
        kernel: _OpOverloads,
        inputs: Union[TensorBox, list[TensorBox]],
        *args: Any,
        **kwargs: Any,
    ) -> Union[list[MultiOutput], _CollectiveKernel]:
        """Create an out-of-place collective kernel for ``kernel`` on ``inputs``.

        Returns a list of ``MultiOutput`` nodes when the example output is a
        list, otherwise the kernel itself. Outputs are added to
        ``V.graph.unaligned_buffers`` when the config assumes unaligned
        fallback output or the example tensor is not aligned.
        """
        with V.graph.fake_mode:
            (
                example_output,
                tensor_args,
                non_tensor_args,
                unflatten_args,
                unbacked_bindings,
            ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
        assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}"
        for tensor_arg in tensor_args:
            tensor_arg.realize()

        if isinstance(example_output, list):
            # Multi-output case: one MultiOutput per example tensor.
            device = cls.find_device(tensor_args, example_output)
            assert device is not None
            packed = cls(
                MultiOutputLayout(device=device),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
            packed.outputs = [
                MultiOutput(
                    cls.tensor_to_layout(tensor),
                    packed,
                    [(list, i)],
                )
                for i, tensor in enumerate(example_output)
            ]
            for buf, tensor in zip(packed.outputs, example_output):
                if config.assume_unaligned_fallback_output or not tensor_is_aligned(
                    tensor
                ):
                    V.graph.unaligned_buffers.add(buf.name)  # type: ignore[arg-type]
            return packed.outputs
        else:
            # Single-output case: the kernel is its own (only) output.
            packed = cls(
                cls.tensor_to_layout(example_output),
                kernel,
                tensor_args,
                non_tensor_args,
                unflatten_args,
            )
            if config.assume_unaligned_fallback_output or not tensor_is_aligned(
                example_output
            ):
                V.graph.unaligned_buffers.add(packed.name)  # type: ignore[arg-type]
            packed.outputs = [packed]
            return packed
  7854. class _AllReduce_Kernel(_CollectiveKernel):
  7855. def __init__(
  7856. self,
  7857. layout: OutputSpec,
  7858. kernel: _OpOverloads,
  7859. tensor_args: Sequence[IRNode],
  7860. nontensor_args: Sequence[Any],
  7861. unflatten_args: Callable[..., Any],
  7862. kwargs: Optional[dict[str, Any]] = None,
  7863. *,
  7864. unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
  7865. ) -> None:
  7866. super().__init__(
  7867. layout,
  7868. kernel,
  7869. tensor_args,
  7870. nontensor_args,
  7871. unflatten_args,
  7872. kwargs=None,
  7873. unbacked_bindings=unbacked_bindings,
  7874. )
  7875. self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce_")
  7876. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7877. wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
  7878. wrapper.generate_extern_kernel_alloc(self)
  7879. if isinstance(self.layout, Layout):
  7880. self.codegen_size_asserts(wrapper)
  7881. class _AllReduceKernel(_CollectiveKernel):
  7882. def __init__(
  7883. self,
  7884. layout: OutputSpec,
  7885. kernel: _OpOverloads,
  7886. tensor_args: Sequence[IRNode],
  7887. nontensor_args: Sequence[Any],
  7888. unflatten_args: Callable[..., Any],
  7889. kwargs: Optional[dict[str, Any]] = None,
  7890. *,
  7891. unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
  7892. ) -> None:
  7893. super().__init__(
  7894. layout,
  7895. kernel,
  7896. tensor_args,
  7897. nontensor_args,
  7898. unflatten_args,
  7899. kwargs=None,
  7900. unbacked_bindings=unbacked_bindings,
  7901. )
  7902. self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce")
  7903. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7904. wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
  7905. wrapper.generate_extern_kernel_alloc(self)
  7906. if isinstance(self.layout, Layout):
  7907. self.codegen_size_asserts(wrapper)
  7908. class _WaitKernel(_CollectiveKernel):
  7909. def __init__(
  7910. self,
  7911. layout: OutputSpec,
  7912. kernel: _OpOverloads,
  7913. tensor_args: Sequence[IRNode],
  7914. nontensor_args: Sequence[Any],
  7915. unflatten_args: Callable[..., Any],
  7916. kwargs: Optional[dict[str, Any]] = None,
  7917. *,
  7918. unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
  7919. ) -> None:
  7920. super().__init__(
  7921. layout,
  7922. kernel,
  7923. tensor_args,
  7924. nontensor_args,
  7925. unflatten_args,
  7926. kwargs=None,
  7927. unbacked_bindings=unbacked_bindings,
  7928. )
  7929. self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_wait_tensor")
  7930. def codegen(self, wrapper: PythonWrapperCodegen) -> None:
  7931. wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
  7932. wrapper.generate_extern_kernel_alloc(self)
  7933. if isinstance(self.layout, Layout):
  7934. self.codegen_size_asserts(wrapper)
  7935. def get_volatile_reads(self) -> Sequence[IRNode]:
  7936. inp = self.inputs[0]
  7937. assert isinstance(inp, IRNode)
  7938. if isinstance(inp, _CollectiveKernel):
  7939. # Out-of-place single-output
  7940. i = inp.inputs[0]
  7941. assert isinstance(i, IRNode), type(i)
  7942. return [i]
  7943. elif isinstance(inp, MultiOutput):
  7944. # This can be two things:
  7945. # 1. Out-of-place multi-output coll
  7946. # 2. In-place coll with inputs coming from another MultiOutput
  7947. coll = inp.inputs[0]
  7948. # Case 1
  7949. if isinstance(coll, _CollectiveKernel):
  7950. _, idx = inp.indices[0]
  7951. return [coll.inputs[idx]]
  7952. # Case 2
  7953. return []
  7954. else:
  7955. # In-place requires no additional deps handling for volatile
  7956. # reads since the inputs are mutated.
  7957. return []
  7958. @classmethod
  7959. def create_wait(cls, kernel: _OpOverloads, inp: TensorBox) -> None:
  7960. with V.graph.fake_mode:
  7961. (
  7962. _example_output,
  7963. tensor_args,
  7964. non_tensor_args,
  7965. unflatten_args,
  7966. unbacked_bindings,
  7967. ) = cls.process_kernel(kernel, inp)
  7968. assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
  7969. packed = cls(
  7970. NoneLayout(device=inp.get_device()),
  7971. kernel,
  7972. tensor_args,
  7973. non_tensor_args,
  7974. unflatten_args,
  7975. )
  7976. packed.mutation_outputs.append(
  7977. MutationOutput(NoneLayout(device=inp.get_device()), inp, packed)
  7978. )
  7979. def get_read_writes(self) -> dependencies.ReadWrites:
  7980. read_writes = super().get_read_writes()
  7981. # See [Out-of-Place Collective Safety].
  7982. volatile_reads = self.get_volatile_reads()
  7983. for vr in volatile_reads:
  7984. read_writes.reads.add(dependencies.StarDep(vr.get_name()))
  7985. return read_writes
  7986. # NB: recursive structure here reflects val_to_arg_str, avoid
  7987. # calling free_unbacked_symbols on "exotic" types that don't get pexpr
  7988. # treatment
  7989. def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]:
  7990. if isinstance(s, (SymTypes, Expr)):
  7991. # This branch should be impossible in return position
  7992. return free_unbacked_symbols(s)
  7993. elif isinstance(s, (tuple, list)):
  7994. r = OrderedSet[sympy.Symbol]()
  7995. for t in s:
  7996. r |= maybe_free_unbacked_symbols(t)
  7997. return r
  7998. elif isinstance(s, torch.Tensor):
  7999. # This branch is impossible in constant-args position
  8000. return free_unbacked_symbols(s)
  8001. else:
  8002. return OrderedSet()
  8003. def maybe_free_symbols(s: object) -> OrderedSet[Symbol]:
  8004. if isinstance(s, (SymTypes, Expr)):
  8005. # This branch should be impossible in return position
  8006. return free_symbols(s)
  8007. elif isinstance(s, (tuple, list)):
  8008. r = OrderedSet[sympy.Symbol]()
  8009. for t in s:
  8010. r |= maybe_free_symbols(t)
  8011. return r
  8012. elif isinstance(s, torch.Tensor):
  8013. # This branch is impossible in constant-args position
  8014. return free_symbols(s)
  8015. else:
  8016. return OrderedSet()