| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798991001011021031041051061071081091101111121131141151161171181191201211221231241251261271281291301311321331341351361371381391401411421431441451461471481491501511521531541551561571581591601611621631641651661671681691701711721731741751761771781791801811821831841851861871881891901911921931941951961971981992002012022032042052062072082092102112122132142152162172182192202212222232242252262272282292302312322332342352362372382392402412422432442452462472482492502512522532542552562572582592602612622632642652662672682692702712722732742752762772782792802812822832842852862872882892902912922932942952962972982993003013023033043053063073083093103113123133143153163173183193203213223233243253263273283293303313323333343353363373383393403413423433443453463473483493503513523533543553563573583593603613623633643653663673683693703713723733743753763773783793803813823833843853863873883893903913923933943953963973983994004014024034044054064074084094104114124134144154164174184194204214224234244254264274284294304314324334344354364374384394404414424434444454464474484494504514524534544554564574584594604614624634644654664674684694704714724734744754764774784794804814824834844854864874884894904914924934944954964974984995005015025035045055065075085095105115125135145155165175185195205215225235245255265275285295305315325335345355365375385395405415425435445455465475485495505515525535545555565575585595605615625635645655665675685695705715725735745755765775785795805815825835845855865875885895905915925935945955965975985996006016026036046056066076086096106116126136146156166176186196206216226236246256266276286296306316326336346356366376386396406416426436446456466476486496506516526536546556566576586596606616626636646656666676686696706716726736746756766776786796806816826836846856866876886896906916926936946956966976986997007017027037047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761777177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236223722382239224022412242224322442245224622472248224922502251225222532254225522562257225822592260226122622263226422652266226722682269227022712272227322742275227622772278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920292129222923292429252926292729282929293029312932293329342935293629372938293929402941294229432944294529462947294829492950295129522953295429552956295729582959296029612962296329642965296629672968296929702971297229732974297529762977297829792980298129822983298429852986298729882989299029912992299329942995299629972998299930003001300230033004300530063007300830093010301130123013301430153016301730183019302030213022302330243025302630273028302930303031303230333034303530363037303830393040304130423043304430453046304730483049305030513052305330543055305630573058305930603061306230633064306530663067306830693070307130723073307430753076307730783079308030813082308330843085308630873088308930903091309230933094309530963097309830993100310131023103310431053106310731083109311031113112311331143115311631173118311931203121312231233124312531263127312831293130313131323133313431353136313731383139314031413142314331443145314631473148314931503151315231533154315531563157315831593160316131623163316431653166316731683169317031713172317331743175317631773178317931803181318231833184318531863187318831893190319131923193319431953196319731983199320032013202320332043205320632073208320932103211321232133214321532163217321832193220322132223223322432253226322732283229323032313232323332343235323632373238323932403241324232433244324532463247324832493250325132523253325432553256325732583259326032613262326332643265326632673268326932703271327232733274327532763277327832793280328132823283328432853286328732883289329032913292329332943295329632973298329933003301330233033304330533063307330833093310331133123313331433153316331733183319332033213322332333243325332633273328332933303331333233333334333533363337333833393340334133423343334433453346334733483349335033513352335333543355335633573358335933603361336233633364336533663367336833693370337133723373337433753376337733783379338033813382338333843385338633873388338933903391339233933394339533963397339833993400340134023403340434053406340734083409341034113412341334143415341634173418341934203421342234233424342534263427342834293430343134323433343434353436343734383439344034413442344334443445344634473448344934503451345234533454345534563457345834593460346134623463346434653466346734683469347034713472347334743475347634773478347934803481348234833484348534863487348834893490349134923493349434953496349734983499350035013502350335043505350635073508350935103511351235133514351535163517351835193520352135223523352435253526352735283529353035313532353335343535353635373538353935403541354235433544354535463547354835493550355135523553355435553556355735583559356035613562356335643565356635673568356935703571357235733574357535763577357835793580358135823583358435853586358735883589359035913592359335943595359635973598359936003601360236033604360536063607360836093610361136123613361436153616361736183619362036213622362336243625362636273628362936303631363236333634363536363637363836393640364136423643364436453646364736483649365036513652365336543655365636573658365936603661366236633664366536663667366836693670367136723673367436753676367736783679368036813682368336843685368636873688368936903691369236933694369536963697369836993700370137023703370437053706370737083709371037113712371337143715371637173718371937203721372237233724372537263727372837293730373137323733373437353736373737383739374037413742374337443745374637473748374937503751375237533754375537563757375837593760376137623763376437653766376737683769377037713772377337743775377637773778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427742784279428042814282428342844285428642874288428942904291429242934294429542964297429842994300430143024303430443054306430743084309431043114312431343144315431643174318431943204321432243234324432543264327432843294330433143324333433443354336433743384339434043414342434343444345434643474348434943504351435243534354435543564357435843594360436143624363436443654366436743684369437043714372437343744375437643774378437943804381438243834384438543864387438843894390439143924393439443954396439743984399440044014402440344044405440644074408440944104411441244134414441544164417441844194420442144224423442444254426442744284429443044314432443344344435443644374438443944404441444244434444444544464447444844494450445144524453445444554456445744584459446044614462446344644465446644674468446944704471447244734474447544764477447844794480448144824483448444854486448744884489449044914492449344944495449644974498449945004501450245034504450545064507450845094510451145124513451445154516451745184519452045214522452345244525452645274528452945304531453245334534453545364537453845394540454145424543454445454546454745484549455045514552455345544555455645574558455945604561456245634564456545664567456845694570457145724573457445754576457745784579458045814582458345844585458645874588458945904591459245934594459545964597459845994600460146024603460446054606460746084609461046114612461346144615461646174618461946204621462246234624462546264627462846294630463146324633463446354636463746384639464046414642464346444645464646474648464946504651465246534654465546564657465846594660466146624663466446654666466746684669467046714672467346744675467646774678467946804681468246834684468546864687468846894690469146924693469446954696469746984699470047014702470347044705470647074708470947104711471247134714471547164717471847194720472147224723472447254726472747284729473047314732473347344735473647374738473947404741474247434744474547464747474847494750475147524753475447554756475747584759476047614762476347644765476647674768476947704771477247734774477547764777477847794780478147824783478447854786478747884789479047914792479347944795479647974798479948004801480248034804480548064807480848094810481148124813481448154816481748184819482048214822482348244825482648274828482948304831483248334834483548364837483848394840484148424843484448454846484748484849485048514852485348544855485648574858485948604861486248634864486548664867486848694870487148724873487448754876487748784879488048814882488348844885488648874888488948904891489248934894489548964897489848994900490149024903490449054906490749084909491049114912491349144915491649174918491949204921492249234924492549264927492849294930493149324933493449354936493749384939494049414942494349444945494649474948494949504951495249534954495549564957495849594960496149624963496449654966496749684969497049714972497349744975497649774978497949804981498249834984498549864987498849894990499149924993499449954996499749984999500050015002500350045005500650075008500950105011501250135014501550165017501850195020502150225023502450255026502750285029503050315032503350345035503650375038503950405041504250435044504550465047504850495050505150525053505450555056505750585059506050615062506350645065506650675068506950705071507250735074507550765077507850795080508150825083508450855086508750885089509050915092509350945095509650975098509951005101510251035104510551065107510851095110511151125113511451155116511751185119512051215122512351245125512651275128512951305131513251335134513551365137513851395140514151425143514451455146514751485149515051515152515351545155515651575158515951605161516251635164516551665167516851695170517151725173517451755176517751785179518051815182518351845185518651875188518951905191519251935194519551965197519851995200520152025203520452055206520752085209521052115212521352145215521652175218521952205221522252235224522552265227522852295230523152325233523452355236523752385239524052415242524352445245524652475248524952505251525252535254525552565257525852595260526152625263526452655266526752685269527052715272527352745275527652775278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577757785779578057815782578357845785578657875788578957905791579257935794579557965797579857995800580158025803580458055806580758085809581058115812581358145815581658175818581958205821582258235824582558265827582858295830583158325833583458355836583758385839584058415842584358445845584658475848584958505851585258535854585558565857585858595860586158625863586458655866586758685869587058715872587358745875587658775878587958805881588258835884588558865887588858895890589158925893589458955896589758985899590059015902590359045905590659075908590959105911591259135914591559165917591859195920592159225923592459255926592759285929593059315932593359345935593659375938593959405941594259435944594559465947594859495950595159525953595459555956595759585959596059615962596359645965596659675968596959705971597259735974597559765977597859795980598159825983598459855986598759885989599059915992599359945995599659975998599960006001600260036004600560066007600860096010601160126013601460156016601760186019602060216022602360246025602660276028602960306031603260336034603560366037603860396040604160426043604460456046604760486049605060516052605360546055605660576058605960606061606260636064606560666067606860696070607160726073607460756076607760786079608060816082608360846085608660876088608960906091609260936094609560966097609860996100610161026103610461056106610761086109611061116112611361146115611661176118611961206121612261236124612561266127612861296130613161326133613461356136613761386139614061416142614361446145614661476148614961506151615261536154615561566157615861596160616161626163616461656166616761686169617061716172617361746175617661776178617961806181618261836184618561866187618861896190619161926193619461956196619761986199620062016202620362046205620662076208620962106211621262136214621562166217621862196220622162226223622462256226622762286229623062316232623362346235623662376238623962406241624262436244624562466247624862496250625162526253625462556256625762586259626062616262626362646265626662676268626962706271627262736274627562766277627862796280628162826283628462856286628762886289629062916292629362946295629662976298629963006301630263036304630563066307630863096310631163126313631463156316631763186319632063216322632363246325632663276328632963306331633263336334633563366337633863396340634163426343634463456346634763486349635063516352635363546355635663576358635963606361636263636364636563666367636863696370637163726373637463756376637763786379638063816382638363846385638663876388638963906391639263936394639563966397639863996400640164026403640464056406640764086409641064116412641364146415641664176418641964206421642264236424642564266427642864296430643164326433643464356436643764386439644064416442644364446445644664476448644964506451645264536454645564566457645864596460646164626463646464656466646764686469647064716472647364746475647664776478647964806481648264836484648564866487648864896490649164926493649464956496649764986499650065016502650365046505650665076508650965106511651265136514651565166517651865196520652165226523652465256526652765286529653065316532653365346535653665376538653965406541654265436544654565466547654865496550655165526553655465556556655765586559656065616562656365646565656665676568656965706571657265736574657565766577657865796580658165826583658465856586658765886589659065916592659365946595659665976598659966006601660266036604660566066607660866096610661166126613661466156616661766186619662066216622662366246625662666276628662966306631663266336634663566366637663866396640664166426643664466456646664766486649665066516652665366546655665666576658665966606661666266636664666566666667666866696670667166726673667466756676667766786679668066816682668366846685668666876688668966906691669266936694669566966697669866996700670167026703670467056706670767086709671067116712671367146715671667176718671967206721672267236724672567266727672867296730673167326733673467356736673767386739674067416742674367446745674667476748674967506751675267536754675567566757675867596760676167626763676467656766676767686769677067716772677367746775677667776778677967806781678267836784678567866787678867896790679167926793679467956796679767986799680068016802680368046805680668076808680968106811681268136814681568166817681868196820682168226823682468256826682768286829683068316832683368346835683668376838683968406841684268436844684568466847684868496850685168526853685468556856685768586859686068616862686368646865686668676868686968706871687268736874687568766877687868796880688168826883688468856886688768886889689068916892689368946895689668976898689969006901690269036904690569066907690869096910691169126913691469156916691769186919692069216922692369246925692669276928692969306931693269336934693569366937693869396940694169426943694469456946694769486949695069516952695369546955695669576958695969606961696269636964696569666967696869696970697169726973697469756976697769786979698069816982698369846985698669876988698969906991699269936994699569966997699869997000700170027003700470057006700770087009701070117012701370147015701670177018701970207021702270237024702570267027702870297030703170327033703470357036703770387039704070417042704370447045704670477048704970507051705270537054705570567057705870597060706170627063706470657066706770687069707070717072707370747075707670777078707970807081708270837084708570867087708870897090709170927093709470957096709770987099710071017102710371047105710671077108710971107111711271137114711571167117711871197120712171227123712471257126712771287129713071317132713371347135713671377138713971407141714271437144714571467147714871497150715171527153715471557156715771587159716071617162716371647165716671677168716971707171717271737174717571767177717871797180718171827183718471857186718771887189719071917192719371947195719671977198719972007201720272037204720572067207720872097210721172127213721472157216721772187219722072217222722372247225722672277228722972307231723272337234723572367237723872397240724172427243724472457246724772487249725072517252725372547255725672577258725972607261726272637264726572667267726872697270727172727273727472757276727772787279728072817282728372847285728672877288728972907291729272937294729572967297729872997300730173027303730473057306730773087309731073117312731373147315731673177318731973207321732273237324732573267327732873297330733173327333733473357336733773387339734073417342734373447345734673477348734973507351735273537354735573567357735873597360736173627363736473657366736773687369737073717372737373747375737673777378737973807381738273837384738573867387738873897390739173927393739473957396739773987399740074017402740374047405740674077408740974107411741274137414741574167417741874197420742174227423742474257426742774287429743074317432743374347435743674377438743974407441744274437444744574467447744874497450745174527453745474557456745774587459746074617462746374647465746674677468746974707471747274737474747574767477747874797480748174827483748474857486748774887489749074917492749374947495749674977498749975007501750275037504750575067507750875097510751175127513751475157516751775187519752075217522752375247525752675277528752975307531753275337534753575367537753875397540754175427543754475457546754775487549755075517552755375547555755675577558755975607561756275637564756575667567756875697570757175727573757475757576757775787579758075817582758375847585758675877588758975907591759275937594759575967597759875997600760176027603760476057606760776087609761076117612761376147615761676177618761976207621762276237624762576267627762876297630763176327633763476357636763776387639764076417642764376447645764676477648764976507651765276537654765576567657765876597660766176627663766476657666766776687669767076717672767376747675767676777678767976807681768276837684768576867687768876897690769176927693769476957696769776987699770077017702770377047705770677077708770977107711771277137714771577167717771877197720772177227723772477257726772777287729773077317732773377347735773677377738773977407741774277437744774577467747774877497750775177527753775477557756775777587759776077617762776377647765776677677768776977707771777277737774777577767777777877797780778177827783778477857786778777887789779077917792779377947795779677977798779978007801780278037804780578067807780878097810781178127813781478157816781778187819782078217822782378247825782678277828782978307831783278337834783578367837783878397840784178427843784478457846784778487849785078517852785378547855785678577858785978607861786278637864786578667867786878697870787178727873787478757876787778787879788078817882788378847885788678877888788978907891789278937894789578967897789878997900790179027903790479057906790779087909791079117912791379147915791679177918791979207921792279237924792579267927792879297930793179327933793479357936793779387939794079417942794379447945794679477948794979507951795279537954795579567957795879597960796179627963796479657966796779687969797079717972797379747975797679777978797979807981798279837984798579867987798879897990799179927993799479957996799779987999800080018002800380048005800680078008800980108011801280138014801580168017801880198020802180228023802480258026802780288029803080318032803380348035803680378038803980408041804280438044804580468047804880498050805180528053805480558056805780588059806080618062806380648065806680678068806980708071807280738074807580768077807880798080808180828083808480858086808780888089809080918092809380948095809680978098809981008101810281038104810581068107810881098110811181128113811481158116811781188119812081218122812381248125812681278128812981308131813281338134813581368137813881398140814181428143814481458146814781488149815081518152815381548155815681578158815981608161816281638164816581668167816881698170817181728173817481758176817781788179818081818182818381848185818681878188818981908191819281938194819581968197819881998200820182028203820482058206820782088209821082118212821382148215821682178218821982208221822282238224822582268227822882298230823182328233823482358236823782388239824082418242824382448245824682478248824982508251825282538254825582568257825882598260826182628263826482658266826782688269827082718272827382748275827682778278827982808281828282838284828582868287828882898290829182928293829482958296829782988299830083018302830383048305830683078308830983108311831283138314831583168317831883198320832183228323832483258326832783288329833083318332833383348335833683378338833983408341834283438344834583468347834883498350835183528353835483558356835783588359836083618362836383648365836683678368836983708371837283738374837583768377837883798380838183828383838483858386838783888389839083918392839383948395839683978398839984008401840284038404840584068407840884098410841184128413841484158416841784188419842084218422842384248425842684278428842984308431843284338434843584368437843884398440844184428443844484458446844784488449845084518452845384548455845684578458845984608461846284638464846584668467846884698470847184728473847484758476847784788479848084818482848384848485848684878488848984908491849284938494849584968497849884998500850185028503850485058506850785088509851085118512851385148515851685178518851985208521852285238524852585268527852885298530853185328533853485358536853785388539854085418542854385448545854685478548854985508551855285538554855585568557855885598560856185628563856485658566856785688569857085718572857385748575857685778578857985808581858285838584858585868587858885898590859185928593859485958596859785988599860086018602860386048605860686078608860986108611861286138614861586168617861886198620862186228623862486258626862786288629863086318632863386348635863686378638863986408641864286438644864586468647864886498650865186528653865486558656865786588659866086618662866386648665866686678668866986708671867286738674867586768677867886798680868186828683868486858686868786888689869086918692869386948695869686978698869987008701870287038704870587068707870887098710871187128713871487158716871787188719872087218722872387248725872687278728872987308731873287338734873587368737873887398740874187428743874487458746874787488749875087518752875387548755875687578758875987608761876287638764876587668767876887698770877187728773877487758776877787788779878087818782878387848785878687878788878987908791879287938794879587968797879887998800880188028803880488058806880788088809881088118812881388148815881688178818881988208821882288238824882588268827882888298830883188328833883488358836883788388839884088418842884388448845884688478848884988508851885288538854885588568857885888598860886188628863886488658866886788688869887088718872887388748875887688778878887988808881888288838884888588868887888888898890889188928893889488958896889788988899890089018902890389048905890689078908890989108911891289138914891589168917891889198920892189228923892489258926892789288929893089318932893389348935893689378938893989408941894289438944894589468947894889498950895189528953895489558956895789588959896089618962896389648965896689678968896989708971897289738974897589768977897889798980898189828983898489858986898789888989899089918992899389948995899689978998899990009001900290039004900590069007900890099010901190129013901490159016901790189019902090219022902390249025902690279028902990309031903290339034903590369037903890399040904190429043904490459046904790489049905090519052905390549055905690579058905990609061906290639064906590669067906890699070907190729073907490759076907790789079908090819082908390849085908690879088908990909091909290939094909590969097909890999100910191029103910491059106910791089109911091119112911391149115911691179118911991209121912291239124912591269127912891299130913191329133913491359136913791389139914091419142914391449145914691479148914991509151915291539154915591569157915891599160916191629163916491659166916791689169917091719172917391749175917691779178917991809181918291839184918591869187918891899190919191929193919491959196919791989199920092019202920392049205920692079208920992109211921292139214921592169217921892199220922192229223922492259226922792289229923092319232923392349235923692379238923992409241924292439244924592469247924892499250925192529253925492559256925792589259926092619262926392649265926692679268926992709271927292739274927592769277927892799280928192829283928492859286928792889289929092919292929392949295929692979298929993009301930293039304930593069307930893099310931193129313931493159316931793189319932093219322932393249325932693279328932993309331933293339334933593369337933893399340934193429343934493459346934793489349935093519352935393549355935693579358935993609361936293639364936593669367936893699370937193729373937493759376937793789379938093819382938393849385938693879388938993909391939293939394939593969397939893999400940194029403940494059406940794089409 |
- from __future__ import annotations
- import contextlib
- import dataclasses
- import functools
- import itertools
- import logging
- import operator
- import os
- import textwrap
- import traceback
- from collections.abc import Container, Generator, Iterable, Iterator, Sequence
- from contextlib import AbstractContextManager, nullcontext
- from enum import Enum
- from functools import partial
- from typing import (
- Any,
- Callable,
- cast,
- ClassVar,
- Literal,
- Optional,
- overload,
- SupportsFloat,
- SupportsInt,
- TYPE_CHECKING,
- TypeVar,
- Union,
- )
- from typing_extensions import (
- assert_never,
- Never,
- override,
- ParamSpec,
- Self,
- TypeAlias,
- TypeIs,
- )
- from unittest.mock import patch
- import sympy
- from sympy import Expr, Integer, Symbol
- import torch._export.serde.schema as export_schema
- import torch._library.utils as library_utils
- import torch._logging
- import torch.fx
- import torch.utils._pytree as pytree
- from torch._dynamo.utils import identity
- from torch._export.serde.serialize import GraphModuleSerializer
- from torch._higher_order_ops.auto_functionalize import can_auto_functionalize
- from torch._inductor import metrics
- from torch._inductor.utils import get_free_symbols
- from torch._prims_common import (
- compute_required_storage_length,
- is_boolean_dtype,
- is_float_dtype,
- make_channels_last_strides_for,
- StrideType,
- )
- from torch._subclasses.fake_tensor import get_schema_info
- from torch.fx.experimental.symbolic_shapes import (
- _remove_effect_token_unbacked_bindings,
- compute_unbacked_bindings,
- free_symbols,
- free_unbacked_symbols,
- IterateExprs,
- rebind_unbacked,
- resolve_unbacked_bindings,
- ShapeEnv,
- SymTypes,
- )
- from torch.fx.node import Node
- from torch.utils._ordered_set import OrderedSet
- from torch.utils._sympy.functions import CleanDiv, FloorDiv, ModularIndexing
- from torch.utils._sympy.symbol import SymT
- from . import config, dependencies
- from .codegen.common import (
- BackendFeature,
- CodegenSymbol,
- get_scheduling_for_device,
- index_prevent_reordering,
- Kernel,
- )
- from .dependencies import (
- Dep,
- extract_free_symbols,
- extract_input_node_reduction_ranges,
- extract_read_writes,
- var_builder,
- )
- from .loop_body import LoopBody
- from .ops_handler import OpCounterCSE, OpCountResult, ReductionType, StoreMode
- from .runtime.benchmarking import benchmarker
- from .runtime.hints import DeviceProperties, ReductionHint
- from .utils import (
- argsort,
- argsort_sym,
- cache_on_self,
- cache_on_self_and_args,
- ceildiv,
- convert_shape_to_inductor,
- convert_shape_to_symint,
- developer_warning,
- do_bench_using_profiling,
- dtype_from_size,
- get_dtype_size,
- get_kernel_metadata,
- GPU_ALIGN_BYTES,
- ir_dataclass,
- is_dynamic,
- is_gpu,
- sympy_dot,
- sympy_index_symbol,
- sympy_index_symbol_with_prefix,
- sympy_product,
- sympy_subs,
- tensor_is_aligned,
- )
- from .virtualized import ops, OpsValue, V
- if TYPE_CHECKING:
- from torch._library.fake_class_registry import FakeScriptObject
- from torch.fx.experimental.symbolic_shapes import SympyBoolean
- from torch.fx.node import Argument
- from .codegen.cuda.cuda_template import CUDATemplate
- from .codegen.wrapper import PythonWrapperCodegen
- from .graph import GraphLowering
- from .utils import IndentedBuffer
- else:
- CUDATemplate: TypeAlias = object
- try:
- import triton
- triton_version = triton.__version__
- has_triton = True
- except ImportError:
- triton_version = None
- has_triton = False
- _P = ParamSpec("_P")
- _T = TypeVar("_T")
- _U = TypeVar("_U")
- _V = TypeVar("_V")
- _IntLike: TypeAlias = Union[int, Expr]
- _NumLike: TypeAlias = Union[int, float, Expr]
- _OpOverloads: TypeAlias = Union[torch._ops.OpOverload, torch._ops.HigherOrderOperator]
- log = logging.getLogger(__name__)
- indent = functools.partial(textwrap.indent, prefix=" ")
- aten = torch.ops.aten
- autotune_warmup = int(os.getenv("TORCH_AUTOTUNE_WARMUP", 25))
- autotune_rep = int(os.getenv("TORCH_AUTOTUNE_REP", 100))
- """ [Note: Inductor IR]
- Inductor's IR is produced by executing 'lowering' code (see lowering.py). Each
- lowering is registered to a particular aten operator, and expects inputs that
- correspond to the aten schema. However, in place of torch Tensor inputs, lowerings
- expect Inductor TensorBox inputs.
- TensorBox IR represents torch tensors. Tensors are sometimes single objects owning
- storage, and sometimes views of another Tensor's storage. Mutating tensor operations
- (such as add_()) affect the underlying storage and any associated views. Other operations
- (such as .t_()) update metadata about the current view but don't modify the underlying storage.
- To model this in Inductor, the IR distinguishes between TensorBox, View, StorageBox and Buffer.
- TensorBox is the top level IR construct that any lowering should produce and maps to a torch.Tensor
- output from an operation. But just as torch.Tensors take different forms, TensorBox IR can
- reference View IR or directly reference StorageBox IRs.
- Some Inductor lowerings produce new sets of 'Box'es, while others (such as .t() or other view ops)
- may take an existing TensorBox and point it to a new underlying View IR.
- Tensors that directly own storage are represented as a chain of:
- TensorBox -> StorageBox -> Buffer
- where Buffer is a simple (1D) allocation, and StorageBox introduces the concept of a Layout.
- If you mutate the data of such a tensor, we swing the StorageBox pointer to point to a new buffer
- (leaving the old buffer unmodified and functionalizing the operation).
- Tensors backed by views add one more indirection to the IR.
- TensorBox -> View -> StorageBox -> Buffer
- In these cases, the underlying StorageBox/Buffer will be shared with the pre-view TensorBox.
- Computation is represented by Operation nodes, with each operation producing 1
- or more output Buffers. In the case of mutations, these will be new Buffers that have the
- mutated buffer listed in its get_mutation_names().
- It is also possible to have an InputBuffer for which there is no corresponding Operation,
- e.g. it may be a graph input or compile time constant.
- """
- _NodeOrNodes: TypeAlias = Union[
- int,
- "TensorBox",
- dict[str, "TensorBox"],
- "Symbol",
- "IRNode",
- Sequence[
- Optional[Union[int, dict[str, "TensorBox"], "TensorBox", "Symbol", "IRNode"]]
- ],
- ]
- def _is_static(x: object) -> bool:
- return isinstance(x, (int, Integer))
- @dataclasses.dataclass(frozen=True)
- class GraphPartitionSignature:
- # symbol inputs that are necessary for codegen
- symbol_inputs: OrderedSet[sympy.Symbol]
- # mapping from partition input name to IRNode or Expr. Need the name str since
- # we cannot get name from Expr.
- input_nodes: dict[str, Union[IRNode, sympy.Expr, TorchBindObject]]
- output_nodes: list[IRNode]
- # mapping from partition input name to a boolean for whether deallocating it
- # in the partition function
- input_deallocation: dict[str, bool]
- skip_cudagraph: bool
- # name of constants read/written by the graph partition
- constant_names: list[str]
- def validate_ir(node_or_nodes: Optional[_NodeOrNodes]) -> None:
- def _check_tensorbox(nodes: Optional[_NodeOrNodes]) -> None:
- # Could expand this to check deeper properties
- # (e.g. TensorBox points to View or StorageBox)
- if nodes is None:
- pass
- elif isinstance(nodes, (list, tuple)):
- for node in nodes:
- _check_tensorbox(node)
- elif isinstance(nodes, dict):
- for node in nodes.values():
- _check_tensorbox(node)
- else:
- assert isinstance(
- nodes,
- (
- ExpandView,
- DynamicScalar,
- AssertScalar,
- TensorBox,
- sympy.logic.boolalg.Boolean,
- Expr,
- int,
- EffectfulKernel,
- ShapeAsConstantBuffer,
- ),
- ), (
- f"Found {type(nodes)}, which is not a supported top level IR node. See [Note: Inductor IR]"
- )
- # Be picky about the accepted data structure (don't use pytree here)
- _check_tensorbox(node_or_nodes)
- def ops_wrapper(name: str) -> Callable[..., OpsValue]:
- assert isinstance(name, str), type(name)
- def fn(*args: object, **kwargs: object) -> OpsValue:
- return getattr(ops, name)(*args, **kwargs)
- return fn
- def inverse_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]:
- inv_order = dict(zip(order, range(len(order))))
- def reindex(index: Sequence[_T]) -> Sequence[_T]:
- assert len(index) == len(inv_order)
- return [index[inv_order[i]] for i in range(len(index))]
- return reindex
- def same_reorder(order: Sequence[int]) -> Callable[[Sequence[_T]], Sequence[_T]]:
- def reindex(index: Sequence[_T]) -> Sequence[_T]:
- assert len(index) == len(order)
- return [index[order[i]] for i in range(len(index))]
- return reindex
- def fuse_reindexing(
- reindex1: Callable[[Sequence[_U]], Sequence[_V]],
- reindex2: Callable[[Sequence[_T]], Sequence[_U]],
- ) -> Callable[[Sequence[_T]], Sequence[_V]]:
- def reindex(index: Sequence[_T]) -> Sequence[_V]:
- return reindex1(reindex2(index))
- return reindex
- NHWC_STRIDE_ORDER = [3, 0, 2, 1]
- NHWDC_STRIDE_ORDER = [4, 0, 3, 2, 1]
- def get_fill_order(
- seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None
- ) -> Sequence[int]:
- """
- Convert strides to fill order (argsort)
- """
- if shape_env is None or all(isinstance(s, (int, sympy.Integer)) for s in seq):
- sorted_idx: Sequence[int] = argsort(seq)
- else:
- # argsort_sym handles unbacked symints (with the help of the shape_env)
- sorted_idx = argsort_sym(shape_env, seq)
- return sorted_idx
- def stride_order2fill_order(order: Sequence[Union[int, Integer]]) -> Sequence[int]:
- """
- Convert stride order to fill order
- For channel last format,
- stride order = [3, 0, 2, 1] and fill order = [1, 3, 2, 0]
- """
- lookup = {pos: idx for idx, pos in enumerate(order)}
- fill_order = [lookup[i] for i in range(len(order))]
- return fill_order
- def get_stride_order(
- seq: Sequence[Union[int, torch.SymInt, Expr]], shape_env: Optional[ShapeEnv] = None
- ) -> Sequence[int]:
- """
- Convert strides to stride order
- """
- sorted_idx: Sequence[int] = get_fill_order(seq, shape_env)
- out = [0 for _ in range(len(seq))]
- for i, elem in enumerate(sorted_idx):
- out[elem] = i
- return out
- @overload
- def ir_node_to_tensor(x: Literal[None], guard_shape: bool = True) -> None: ...
- @overload
- def ir_node_to_tensor(x: IRNode, guard_shape: bool = True) -> torch.Tensor: ...
- def ir_node_to_tensor(
- x: Optional[IRNode], guard_shape: bool = True
- ) -> Optional[torch.Tensor]:
- if x is None:
- return None
- shape_fn: Callable[[Union[int, Expr]], Union[int, Expr]]
- if not guard_shape:
- shape_fn = V.graph.sizevars.size_hint
- else:
- shape_fn = identity
- size = [shape_fn(s) for s in x.get_size()]
- stride: StrideType
- if is_storage_and_layout(x):
- stride = [shape_fn(s) for s in x.get_layout().stride]
- else:
- stride = FlexibleLayout.contiguous_strides(size)
- dtype = x.get_dtype()
- device = x.get_device()
- size = convert_shape_to_symint(size)
- stride = convert_shape_to_symint(stride)
- with V.graph.sizevars.shape_env.suppress_guards():
- t = torch.empty_strided(
- size=size, stride=stride, dtype=dtype, device=device
- ).zero_()
- return t
- def may_convert_to_optional(
- value: Optional[Sequence[_T]],
- ) -> Optional[Sequence[Optional[_T]]]:
- if isinstance(value, list) and not value:
- # [None] makes sure the cpp wrapper codegen will generate something like
- # {std::nullopt} instead of {}
- return [None]
- return value
- def get_device_type(
- x: Union[IRNode, OutputSpec, torch.device, None, str],
- ) -> Optional[str]:
- if isinstance(x, str) or x is None:
- return x
- elif isinstance(x, torch.device):
- return x.type
- elif isinstance(x, (IRNode, OutputSpec)):
- return get_device_type(x.get_device())
- assert_never(f"get_device_type({x}: {type(x).__name__})")
- def is_triton(x: Union[IRNode, torch.device, None, str]) -> bool:
- device = get_device_type(x)
- # Special case cpu and cuda as using the method below
- # to determine if the scheduler is a triton scheduler subclass
- # requires instantiating a scheduler for them
- if device in ["cpu", "cuda"]:
- if getattr(config, f"{device}_backend") == "triton":
- return True
- return False
- if (
- device is None
- or (device_scheduling := get_scheduling_for_device(device)) is None
- ):
- return False
- from .codegen.triton import TritonScheduling
- assert isinstance(device_scheduling, type), type(device_scheduling)
- return issubclass(device_scheduling, TritonScheduling)
- def is_cpu(x: Union[IRNode, torch.device, None, str]) -> bool:
- return get_device_type(x) == "cpu"
- def is_aligned_realized_tensor_hint(
- x: Union[Buffer, TensorBox], alignment: int
- ) -> bool:
- # Use this as a hint. This won't guard since size_hint doesn't guard.
- if (
- not isinstance(x, IRNode)
- or x.maybe_get_stride() is None
- or free_unbacked_symbols(x.get_stride())
- or free_unbacked_symbols(x.get_size())
- ):
- return False
- aligned_strides = all(
- (V.graph.sizevars.size_hint_or_throw(x.get_stride()[i]) % alignment) == 0
- for i in range(len(x.get_stride()) - 1)
- )
- # if the last dim size is <= 1, stride doesn't matter
- aligned_last_dim = (
- V.graph.sizevars.size_hint_or_throw(x.get_stride()[-1]) == 1
- or V.graph.sizevars.size_hint_or_throw(x.get_size()[-1]) <= 1
- )
- return aligned_last_dim and aligned_strides
- def significant_strides_equal(
- strides1: Sequence[_IntLike],
- strides2: Sequence[_IntLike],
- shape: Sequence[_IntLike],
- ) -> bool:
- """
- Returns true if the strides are equal, ignoring dimensions of size 1 .
- """
- assert len(shape) == len(strides1) and len(strides1) == len(strides2)
- for dim, s1, s2 in zip(shape, strides1, strides2):
- if V.graph.sizevars.statically_known_leq(dim, 1):
- continue
- if not V.graph.sizevars.statically_known_equals(
- s1, s2
- ) and not V.graph.sizevars.symbolic_hint(s1) == V.graph.sizevars.symbolic_hint(
- s2
- ):
- return False
- return True
- def try_match_insignificant_strides(
- tensor: IRNode,
- strides: Sequence[Union[int, torch.SymInt]],
- ) -> IRNode:
- """
- Tries to match the strides of the tensor to those in the meta_strides. Strides of insignificant
- dimensions - size 0 or 1 - will be updated.
- If there are real stride differences (NHWC vs NCHW), or the tensor is not realized, then the input will be returned
- """
- if not is_storage_and_layout(tensor):
- return tensor
- if all(
- V.graph.sizevars.statically_known_equals(s1, s2)
- for s1, s2 in zip(strides, tensor.get_stride())
- ):
- return tensor
- if not significant_strides_equal(strides, tensor.get_stride(), tensor.get_size()):
- return tensor
- storage, old_layout = as_storage_and_layout(tensor)
- new_stride = [*old_layout.stride]
- for i, s in enumerate(tensor.get_size()):
- if V.graph.sizevars.statically_known_leq(s, 1):
- new_stride[i] = strides[i]
- new_layout = FixedLayout(
- old_layout.device,
- old_layout.dtype,
- old_layout.size,
- new_stride,
- old_layout.offset,
- old_layout.is_pinned,
- )
- return TensorBox(ReinterpretView(data=storage, layout=new_layout))
- def gm_original_output_strides(gm: torch.fx.GraphModule) -> None:
- output_node = gm.graph.find_nodes(op="output")[0]
- output_node.meta["user_visible_output_idxs"] = [
- idx for idx, _ in enumerate(output_node.args)
- ]
- from torch._inductor.compile_fx import record_original_output_strides
- record_original_output_strides(gm)
- def get_symbolic_inputs(inputs: Sequence[IRNode]) -> list[Expr]:
- sym_vars: OrderedSet[Expr] = OrderedSet()
- for inp in inputs:
- sym_vars |= get_free_symbols(inp.get_size(), unbacked_only=False)
- sym_vars |= get_free_symbols(inp.get_stride(), unbacked_only=False)
- return list(sym_vars)
- class IRNode:
- """Base class for all intermediate representation (IR) nodes in TorchInductor.
- Note:
- This is an abstract base class. Most methods raise NotImplementedError
- and must be overridden by concrete subclasses.
- """
- _current_origins: ClassVar[OrderedSet[Any]] = OrderedSet()
- # NB: These are kinda weird,
- origins: OrderedSet[Any] = dataclasses.field(init=False)
- # traces back to where the IRNode is created in Inductor
- traceback: Optional[list[str]] = dataclasses.field(init=False)
- origin_node: Optional[torch.fx.Node] = dataclasses.field(init=False)
- @staticmethod
- @contextlib.contextmanager
- def current_origins(origins: OrderedSet[Node]) -> Generator[None, None, None]:
- old = IRNode._current_origins
- IRNode._current_origins = old | origins
- try:
- yield
- finally:
- IRNode._current_origins = old
- @staticmethod
- def is_realized_node(node: IRNode) -> bool:
- return isinstance(
- node,
- (
- ComputedBuffer,
- InputsKernel,
- InputBuffer,
- ReinterpretView,
- TemplateBuffer,
- ),
- )
- def _post_init_setattr(self, attr: str, value: Any) -> None:
- # Intended for use in __post_init__ for enforcing an invariant on a dataclass
- # If you must, can also be used for setting provenance info
- # We would like to try and minimize these usages though
- object.__setattr__(self, attr, value)
- def __post_init__(self) -> None:
- origins = OrderedSet(self._current_origins)
- self._post_init_setattr("origins", origins)
- self._post_init_setattr(
- "traceback", traceback.format_stack() if config.debug_ir_traceback else None
- )
- self._post_init_setattr("origin_node", None)
- def get_read_names(self) -> OrderedSet[str]:
- return OrderedSet(dep.name for dep in self.get_reads())
- def get_traceback(self) -> Optional[list[str]]:
- return self.traceback
- def get_origin_node(self) -> Optional[torch.fx.Node]:
- return self.origin_node
- def get_defining_op(self) -> Optional[Operation]:
- return None
- def get_stack_traces(self) -> OrderedSet[str]:
- # Return stack traces to user model code
- # A single IRNode could correspond to multiple lines of code
- stack_traces: OrderedSet[str] = OrderedSet()
- origins = self.origins
- if isinstance(self, ExternKernel):
- origin_node = self.get_origin_node()
- if self.origin_node:
- origins = OrderedSet([origin_node])
- for node in origins:
- if hasattr(node, "stack_trace") and node.stack_trace:
- # nodes in the backward graph don't have mapping to pre_grad_graph
- stack_traces.add(node.stack_trace)
- else:
- pre_grad_nodes = (
- torch._inductor.debug._inductor_post_to_pre_grad_nodes.get(
- "postToPre", {}
- ).get(node.name, [])
- )
- if not isinstance(pre_grad_nodes, list):
- continue
- for node_name in pre_grad_nodes:
- stack_trace = (
- torch._inductor.debug._inductor_pre_grad_node_stack_trace.get(
- node_name, None
- )
- )
- if stack_trace:
- stack_traces.add(stack_trace)
- return stack_traces
- def common_repr(self, shorten: bool = True) -> Sequence[str]:
- origins = f"origins={getattr(self, 'origins', '')}"
- if shorten and len(origins) > 64:
- # this can get *very* long
- origins = f"{origins[:61]}..."
- if not self.get_stack_traces():
- return [origins]
- stack_trace_str = []
- for stack_trace in self.get_stack_traces():
- stack_trace_str.append("stack_traces = {")
- stack_trace_str += stack_trace.split("\n")
- stack_trace_str.append("}")
- return [origins] + stack_trace_str
- def str_helper(
- self, lines: Sequence[object], shorten: bool = True, multiline: bool = True
- ) -> str:
- lines = list(lines) + list(self.common_repr(shorten))
- lines = list(map(str, lines))
- if multiline:
- new_lines = indent(",\n".join(lines))
- return f"{type(self).__name__}(\n{new_lines}\n)"
- else:
- return f"{type(self).__name__}({lines})"
- def get_dtype(self) -> torch.dtype:
- return self.dtype
- def maybe_get_dtype(self) -> Optional[torch.dtype]:
- try:
- return self.get_dtype()
- except NotImplementedError:
- return None
- def get_layout(self) -> Layout:
- raise NotImplementedError(f"get_layout() is not implemented by {type(self)}!")
- def maybe_get_layout(self) -> Optional[Layout]:
- try:
- return self.get_layout()
- except NotImplementedError:
- return None
- def get_output_spec(self) -> OutputSpec:
- return self.get_layout()
- def maybe_get_output_spec(self) -> Optional[OutputSpec]:
- try:
- return self.get_output_spec()
- except NotImplementedError:
- return None
- def has_tensor_output(self) -> bool:
- """True for single tensor output (excludes MultiOutput)"""
- return isinstance(self.maybe_get_output_spec(), Layout)
- def get_size(self) -> Sequence[Expr]:
- raise NotImplementedError(f"get_size() is not implemented by {type(self)}!")
- def maybe_get_size(self) -> Optional[Sequence[_IntLike]]:
- try:
- return self.get_size()
- except NotImplementedError:
- return None
- @property
- def shape(self) -> Union[_IntLike, sympy.Rel, Sequence[_IntLike]]:
- return self.get_size()
- def get_numel(self) -> Expr:
- return sympy_product(self.get_size())
- def is_zero_elements(self) -> bool:
- return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
- def realize(self) -> Optional[str]:
- """
- If the IRNode refers to data which has not been materialized (e.g.,
- it is a Pointwise/Reduction that could potentially have more
- compute fused into it), realize the IRNode into physical memory,
- ending the possibility of fusing into it, but allowing, e.g., multiple
- users to access the data without having to recompute.
- Check StorageBox.realize for a particularly notable implementation.
- TODO(ezyang): I think, in principle, every IRNode should have an
- implementation of this, and most of the time no-op is OK, but you
- really do have to audit each IRNode for this, so for now, raise
- an error if it's not implemented. Note that some code in graph.py
- will catch this thrown error and suppress it with a warning.
- """
- raise NotImplementedError(f"realize NYI on {type(self)}")
- def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
- raise NotImplementedError(f"codegen_reference NYI on {type(self)}")
- def get_device(self) -> Optional[torch.device]:
- return None
- def get_device_or_error(self) -> torch.device:
- device = self.get_device()
- assert device is not None
- return device
- def has_exceeded_max_reads(self) -> bool:
- return False
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- raise NotImplementedError(type(self).__name__)
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- raise NotImplementedError(type(self).__name__)
- def get_stride(self) -> Sequence[_IntLike]:
- raise NotImplementedError(type(self).__name__)
- def maybe_get_stride(self) -> Optional[Sequence[_IntLike]]:
- try:
- return self.get_stride()
- except NotImplementedError:
- return None
- def get_name(self) -> str:
- raise NotImplementedError(type(self).__name__)
- def maybe_get_name(self) -> Optional[str]:
- try:
- return self.get_name()
- except NotImplementedError:
- return None
- def is_input_buffer(self) -> bool:
- try:
- return self.get_name() in V.graph.graph_inputs
- except NotImplementedError:
- return False
- def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
- return False
- def mark_reuse(self, users: int) -> None:
- pass
- def realize_hint(self) -> None:
- pass
- def unwrap_view(self) -> IRNode:
- raise NotImplementedError(type(self).__name__)
- def freeze_layout(self) -> None:
- raise NotImplementedError(type(self).__name__)
- def freeze_layout_with_stride_order(
- self, order: Sequence[int], allow_padding: bool = False
- ) -> None:
- raise NotImplementedError(type(self).__name__)
- def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
- raise NotImplementedError(type(self).__name__)
- def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None:
- raise NotImplementedError(type(self).__name__)
- def freeze_layout_with_exact_strides(
- self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
- ) -> None:
- raise NotImplementedError(type(self).__name__)
- def get_read_writes(self) -> dependencies.ReadWrites:
- raise NotImplementedError(type(self).__name__)
- def get_reads(self) -> OrderedSet[Dep]:
- return self.get_read_writes().reads
- def num_reads(self) -> int:
- return len(self.get_reads())
- def get_storage_numel(self) -> _IntLike:
- raise NotImplementedError(type(self).__name__)
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- raise NotImplementedError(type(self).__name__)
- def get_reduction_type(self) -> Optional[str]:
- raise NotImplementedError(type(self).__name__)
- def get_reduction_size(self) -> Sequence[Expr]:
- raise NotImplementedError(type(self).__name__)
- def is_extern(self) -> bool:
- return False
- def is_no_op(self) -> bool:
- return False
- def constant_to_device(self, device: torch.device) -> IRNode:
- raise NotImplementedError(type(self).__name__)
- def get_mutation_names(self) -> Sequence[str]:
- raise NotImplementedError(type(self).__name__)
- def get_operation_name(self) -> str:
- raise NotImplementedError(type(self).__name__)
- def get_inputs_that_alias_output(self) -> Sequence[str]:
- raise NotImplementedError(type(self).__name__)
- if TYPE_CHECKING:
- @property
- def dtype(self) -> torch.dtype: ...
- @ir_dataclass(frozen=False)
- class Operation:
- def __post_init__(self) -> None:
- self.operation_name: Optional[str] = None
- def get_device(self) -> Optional[torch.device]:
- raise NotImplementedError
- def get_origin_node(self) -> Optional[torch.fx.Node]:
- assert hasattr(self, "origin_node")
- return self.origin_node
- def get_origins(self) -> OrderedSet[Any]:
- assert hasattr(self, "origins")
- return self.origins
- def get_operation_name(self) -> str:
- assert self.operation_name is not None
- return self.operation_name
- def is_extern(self) -> bool:
- return False
- def is_no_op(self) -> bool:
- return False
- def get_read_writes(self) -> dependencies.ReadWrites:
- raise NotImplementedError
- def is_user_of(self, name: str) -> bool:
- return name in self.get_read_names()
- def get_read_names(self) -> OrderedSet[str]:
- return OrderedSet(dep.name for dep in self.get_reads())
- def get_reads(self) -> OrderedSet[Dep]:
- return self.get_read_writes().reads
- def get_outputs(self) -> list[Buffer]:
- raise NotImplementedError
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- """
- When unbacked_only=True:
- Returns the unbacked symbols which are required to be in scope in
- order to successfully perform codegen for this buffer. For example,
- a buffer that corresponds to an extern kernel call that takes i0 as
- an argument would return {i0} here. This is used to generate necessary
- dependencies that ensure we actually bind i0 in codegen before you
- try to use it.
- Note that this is NOT transitive; in particular, if this buffer takes
- in as input another buffer with dynamic shape (e.g., (i0,)), we will
- not report it here, because you will already have a dependency
- on that buffer, which will eventually have a dependency on i0 if
- necessary.
- When unbacked_only=False:
- Similar to `unbacked_only=True` but including all free symbols
- instead of only free unbacked symbols.
- """
- return OrderedSet()
- def get_workspace_size(self) -> int:
- """
- Gets extra global memory size needed by this buffer.
- Some algorithms (e.g. group gemm) may require extra global memory in the generated code.
- """
- return 0
- @ir_dataclass
- class Loops(IRNode):
- device: torch.device
- dtype: torch.dtype
- inner_fn: Callable[..., Any]
- ranges: Sequence[_IntLike]
- @cache_on_self_and_args("Loops")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return OrderedSet().union(
- *(get_free_symbols(e, unbacked_only) for e in self.ranges),
- self.inner_fn_free_symbols(unbacked_only),
- )
- def _to_str(self, names: Sequence[str]) -> str:
- return self.str_helper(
- [
- f"'{self.device.type}'",
- str(self.dtype),
- self.inner_fn_str(),
- ]
- + [f"{name}={getattr(self, name)}" for name in names]
- + [f"origin_node={self.origin_node!r}"]
- )
- def __post_init__(self) -> None:
- super().__post_init__()
- def __str__(self) -> str:
- return self._to_str(("ranges",))
- __repr__ = __str__
- def get_device(self) -> Optional[torch.device]:
- return self.device
- def get_origin_node(self) -> Optional[torch.fx.Node]:
- return self.origin_node
- def get_size(self) -> Sequence[Expr]:
- return self.ranges
- def get_pointwise_size(self) -> Sequence[Expr]:
- return self.ranges
- @classmethod
- def create(
- cls, *args: Any, **kwargs: Any
- ) -> Union[TensorBox, ShapeAsConstantBuffer]:
- origin_node = kwargs.pop("origin_node", None)
- tb = kwargs.pop("traceback", None)
- r = cls(*args, **kwargs)
- # Need to explicitly set origin_node here to propagate it down.
- # todo(chilli): I think it would be better for IRNode to directly set
- # origin_node
- r._post_init_setattr("origin_node", origin_node)
- r._post_init_setattr("traceback", tb or r.traceback)
- return TensorBox.create(r)
- @staticmethod
- def _index(ranges: Sequence[_IntLike], prefix: SymT = SymT.INDEX) -> Sequence[Expr]:
- return [
- sympy.S.Zero if s == 1 else sympy_index_symbol_with_prefix(prefix, n)
- for n, s in enumerate(ranges)
- ]
- @cache_on_self
- def inner_fn_opcount(self) -> OpCountResult:
- opcounter = OpCounterCSE(V.MockHandler())
- with (
- V.set_ops_handler(opcounter),
- patch.object(FlexibleLayout, "allow_indexing", True),
- ):
- self.inner_fn(*self.inner_fn_args())
- return opcounter.getvalue()
- def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]:
- return (self._index(self.ranges),)
- @cache_on_self
- def inner_fn_str(self) -> str:
- return V.KernelFormatterHandler.ir_to_string(
- self.inner_fn, *self.inner_fn_args()
- )
- def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
- if threshold is None:
- threshold = 0
- threshold = max(threshold, config.realize_opcount_threshold)
- return self.inner_fn_opcount().num_ops > threshold
- def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
- index = self._index(self.ranges)
- return extract_free_symbols(self.inner_fn, index, unbacked_only=unbacked_only)
- def get_reads(self) -> OrderedSet[Dep]:
- with patch.object(FlexibleLayout, "allow_indexing", True):
- if self.get_reduction_type():
- return extract_read_writes(
- self.make_loader(),
- self.get_size(),
- self.get_reduction_size(),
- ).reads
- else:
- return extract_read_writes(
- self.make_loader(),
- self.get_size(),
- ).reads
- def get_read_names(self) -> OrderedSet[str]:
- return OrderedSet(self.inner_fn_opcount().read_buffers)
- def num_reads(self) -> int:
- return len(self.inner_fn_opcount().read_buffers)
- def get_reduction_size(self) -> Sequence[Expr]:
- raise NotImplementedError(
- f"get_reduction_size() is not implemented by {type(self)}!"
- )
- def get_reduction_type(self) -> Optional[str]:
- raise NotImplementedError(
- f"get_reduction_type() is not implemented by {type(self)}!"
- )
- def constant_to_device(self, device: torch.device) -> IRNode:
- raise NotImplementedError(
- f"constant_to_device() is not implemented by {type(self)}!"
- )
- def nop_loader_fn(idx: Union[Expr, Sequence[Expr]], *, dtype: torch.dtype) -> OpsValue:
- if dtype.is_floating_point:
- return ops.constant(float("nan"), dtype)
- else:
- return ops.constant(0, dtype)
- @ir_dataclass
- class Pointwise(Loops):
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- # Make zero-element loops into a no-op
- if self.is_zero_elements():
- return partial(nop_loader_fn, dtype=self.dtype)
- return self.inner_fn
- def get_reduction_size(self) -> Sequence[sympy.Expr]:
- return []
- def get_reduction_type(self) -> Optional[str]:
- return None
- def store_output(
- self,
- output_name: Optional[str],
- indexer: Callable[[Sequence[Expr]], Never],
- vars: Sequence[Expr],
- ) -> None:
- loader = self.make_loader()
- return ops.store(output_name or "unnamed", indexer(vars), loader(vars))
- def constant_to_device(self, device: torch.device) -> IRNode:
- """Move this to a given device. Requires that all reads are to constants."""
- loader = self.make_loader()
- loader = patch.object(ConstantBuffer, "override_device", device)(loader)
- return Pointwise(
- device=device,
- dtype=self.dtype,
- inner_fn=loader,
- ranges=self.ranges,
- )
- @ir_dataclass
- class Scatter(Pointwise):
- output_indexer: Callable[[Sequence[Expr]], Expr]
- scatter_mode: StoreMode = None
- def constant_to_device(self, device: torch.device) -> IRNode:
- """Move this to a given device. Requires that all reads are to constants."""
- loader = self.make_loader()
- loader = patch.object(ConstantBuffer, "override_device", device)(loader)
- return Scatter(
- device=device,
- dtype=self.dtype,
- inner_fn=loader,
- ranges=self.ranges,
- output_indexer=self.output_indexer,
- scatter_mode=self.scatter_mode,
- )
- def store_output(
- self,
- output_name: Optional[str],
- indexer: Callable[[Sequence[Expr]], Never],
- vars: Sequence[Expr],
- ) -> Any:
- loader = self.make_loader()
- if output_name is None:
- output_name = "unnamed"
- return ops.store(
- output_name,
- indexer(self.output_indexer(vars)),
- loader(vars),
- mode=self.scatter_mode,
- )
- REDUCTION_COMBINE_FN: dict[str, Callable[..., OpsValue]] = {
- "any": ops_wrapper("logical_or"),
- "max": ops_wrapper("maximum"),
- "min": ops_wrapper("minimum"),
- "prod": ops_wrapper("mul"),
- "sum": ops_wrapper("add"),
- "xor_sum": ops_wrapper("bitwise_xor"),
- }
- def get_reduction_combine_fn(
- reduction_type: str, dtype: torch.dtype, arg_break_ties_left: bool = True
- ) -> Callable[..., object]:
- if reduction_type in REDUCTION_COMBINE_FN:
- return REDUCTION_COMBINE_FN[reduction_type]
- elif reduction_type in ("argmax", "argmin"):
- def argmax_combine_fn(
- a: tuple[object, object], b: tuple[object, object]
- ) -> tuple[OpsValue, OpsValue]:
- a_value, a_index = a
- b_value, b_index = b
- if reduction_type == "argmin":
- mask = ops.lt(a_value, b_value)
- else:
- mask = ops.gt(a_value, b_value)
- equal = ops.eq(a_value, b_value)
- if is_float_dtype(dtype):
- a_isnan = ops.ne(a_value, a_value)
- b_isnan = ops.ne(b_value, b_value)
- mask = ops.logical_or(mask, ops.gt(a_isnan, b_isnan))
- equal = ops.logical_or(equal, ops.logical_and(a_isnan, b_isnan))
- tie = (
- ops.lt(a_index, b_index)
- if arg_break_ties_left
- else ops.gt(a_index, b_index)
- )
- mask = ops.logical_or(mask, ops.logical_and(equal, tie))
- return (
- ops.where(mask, a_value, b_value),
- ops.where(mask, a_index, b_index),
- )
- return argmax_combine_fn
- elif reduction_type == "welford_combine":
- def welford_combine_fn(
- a: tuple[OpsValue, OpsValue, OpsValue],
- b: tuple[OpsValue, OpsValue, OpsValue],
- ) -> tuple[OpsValue, OpsValue, OpsValue]:
- a_mean, a_m2, a_weight = a
- b_mean, b_m2, b_weight = b
- delta = b_mean - a_mean
- new_weight = a_weight + b_weight
- w2_over_w = b_weight / new_weight
- return (
- a_mean + delta * w2_over_w,
- a_m2 + b_m2 + delta * delta * a_weight * w2_over_w,
- new_weight,
- )
- return welford_combine_fn
- else:
- raise NotImplementedError(f"unknown reduction_type={reduction_type}")
- @ir_dataclass
- class Reduction(Loops):
- reduction_ranges: Sequence[_IntLike]
- reduction_type: ReductionType
- # self.dtype represents the dst dtype
- src_dtype: torch.dtype
- reduction_hint: ReductionHint
- def __str__(self) -> str:
- return self._to_str(("ranges", "reduction_ranges", "reduction_type"))
- __repr__ = __str__
- @cache_on_self_and_args("Reduction")
- def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
- return super().get_free_symbol_uses(unbacked_only) | OrderedSet().union(
- *(get_free_symbols(e, unbacked_only) for e in self.reduction_ranges)
- )
- def get_reduction_size(self) -> Sequence[Expr]:
- return self.reduction_ranges
- def get_reduction_type(self) -> Optional[str]:
- return self.reduction_type
- def store_reduction(
- self,
- output_name: Optional[str],
- indexer: Callable[[Sequence[Expr]], Never],
- vars: Sequence[Expr],
- reduction_vars: Sequence[Symbol],
- ) -> None:
- value = ops.reduction(
- self.dtype,
- self.src_dtype,
- self.reduction_type,
- self.inner_fn(vars, reduction_vars),
- )
- ops.store_reduction(output_name or "unnamed", indexer(vars), value)
- def index_length(self) -> int:
- return len(self.ranges) + len(self.reduction_ranges)
- def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
- index = self._index(self.ranges)
- rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
- return (index, rindex)
- def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
- index = self._index(self.ranges)
- rindex = self._index(self.reduction_ranges, SymT.R0_INDEX)
- return extract_free_symbols(
- self.inner_fn, index, rindex, unbacked_only=unbacked_only
- )
- def constant_to_device(self, device: torch.device) -> IRNode:
- """Move this to a given device. Requires that all reads are to constants."""
- loader = self.make_loader()
- loader = patch.object(ConstantBuffer, "override_device", device)(loader)
- return Reduction(
- device=device,
- dtype=self.dtype,
- inner_fn=loader,
- ranges=self.ranges,
- reduction_ranges=self.reduction_ranges,
- reduction_type=self.reduction_type,
- src_dtype=self.src_dtype,
- reduction_hint=ReductionHint.DEFAULT,
- )
- @staticmethod
- def num_splits(
- device: torch.device,
- dst_dtype: torch.dtype,
- src_dtype: torch.dtype,
- inner_fn: Callable[_P, OpsValue],
- ranges: Sequence[_IntLike],
- reduction_ranges: Sequence[_IntLike],
- reduction_type: Union[ReductionType, Literal["scan"]],
- reduction_numel: Expr,
- input_node: Optional[IRNode] = None,
- ) -> tuple[ReductionHint, _IntLike]:
- reduction_numel_hint = V.graph.sizevars.symbolic_hint(reduction_numel)
- numel_hint = V.graph.sizevars.symbolic_hint(sympy_product(ranges))
- should_split = reduction_type == "scan" or (
- not V.graph.has_feature(device, BackendFeature.REDUCE_TO_SINGLE_ELEMENT)
- and reduction_type
- not in (
- "argmax",
- "argmin",
- )
- and config.split_reductions
- )
- if not (_is_static(reduction_numel_hint) and _is_static(numel_hint)):
- # We don't support unbacked symints
- return ReductionHint.DEFAULT, 1
- props = DeviceProperties.create(device)
- num_sm = props.multi_processor_count
- min_elements_per_thread = 32
- if should_split:
- inner_reduction_splits: Callable[[int, int], int] = functools.partial(
- V.choices.reduction_split_factor, device, inner_reduction=True
- )
- outer_reduction_splits: Callable[[int, int], int] = functools.partial(
- V.choices.reduction_split_factor, device, inner_reduction=False
- )
- else:
- def inner_reduction_splits(
- reduction_numel_hint: int,
- numel_hint: int,
- ) -> int:
- return 1
- outer_reduction_splits = inner_reduction_splits
- # easy cases
- if numel_hint == 1:
- split = inner_reduction_splits(reduction_numel_hint, numel_hint)
- if split == 1:
- # No need to split.
- return ReductionHint.INNER, split
- if input_node is not None and isinstance(input_node, TensorBox):
- with patch.object(FlexibleLayout, "allow_indexing", True):
- (
- new_ranges,
- new_reduction_ranges,
- ) = extract_input_node_reduction_ranges(input_node)
- if new_ranges is not None and new_reduction_ranges is not None:
- extracted_numel_hint = V.graph.sizevars.symbolic_hint(
- sympy_product(new_ranges + new_reduction_ranges)
- )
- if reduction_numel_hint == extracted_numel_hint:
- log.debug(
- "Use previous IRNode's range and reduction_ranges instead of split. "
- "current ranges: %s, current reduction ranges: %s, current split: %d, "
- "new ranges: %s, new reduction ranges: %s",
- ranges,
- reduction_ranges,
- split,
- new_ranges,
- new_reduction_ranges,
- )
- # If the input_node or its dependent nodes are also Reduction nodes,
- # use reduction_sizes of this node or its dependent nodes directly.
- return ReductionHint.INNER, -1
- return ReductionHint.INNER, split
- if (
- reduction_numel_hint <= min_elements_per_thread
- or numel_hint >= num_sm * 2 * 32
- ):
- return ReductionHint.DEFAULT, 1
- r = Reduction(
- device=device,
- dtype=dst_dtype,
- inner_fn=inner_fn,
- ranges=ranges,
- reduction_ranges=reduction_ranges,
- reduction_type=reduction_type if reduction_type != "scan" else "sum",
- src_dtype=src_dtype,
- reduction_hint=ReductionHint.DEFAULT,
- )
- def get_read_indices(r: Reduction) -> tuple[Sequence[Expr], bool]:
- device = r.get_device()
- assert device is not None
- cb = ComputedBuffer(
- name=None,
- layout=FlexibleLayout(
- device=device,
- dtype=r.get_dtype(),
- size=r.get_size(),
- ),
- data=r,
- )
- read_writes = cb.get_read_writes()
- # try finding the full size producer
- # TODO this will fail for something like ((1, N) * (N, 1)).sum()
- # this would also possibly be wrong for producers with the different contiguity but we hope those cases are rare
- assert read_writes.range_vars is not None
- range_vars = [
- r
- for r in read_writes.range_vars
- if isinstance(r, Expr) and not isinstance(r, sympy.Number)
- ]
- indices = []
- changed = False
- for md in sorted(read_writes.reads, key=lambda x: x.name):
- if all(r in md.index.free_symbols for r in range_vars):
- indices.append(md.index)
- if md.name in V.graph.name_to_buffer:
- buf = V.graph.name_to_buffer[md.name]
- original_stride = getattr(buf.layout, "stride", None)
- buf.decide_layout()
- if getattr(buf.layout, "stride", None) != original_stride:
- changed = True
- return indices, changed
- indices, changed = get_read_indices(r)
- if changed:
- indices, _ = get_read_indices(r)
- if len(indices) == 0:
- # TODO determine splits when all inputs are broadcast
- return ReductionHint.DEFAULT, 1
- (_, reduction_vars), ranges1 = dependencies.index_vars_squeeze(
- r.get_size(), r.get_reduction_size()
- )
- num_outer = 0
- num_inner = 0
- for i in indices:
- j = V.graph.sizevars.simplify_with_ranges(i, ranges1)
- strides = V.graph.sizevars.stride_hints(
- j, reduction_vars, list(ranges1.keys())
- )
- outer = all(s > 1 for s in strides)
- if outer:
- num_outer += 1
- else:
- num_inner += 1
- if num_inner > num_outer:
- return ReductionHint.INNER, inner_reduction_splits(
- reduction_numel_hint, numel_hint
- )
- else:
- return ReductionHint.OUTER, outer_reduction_splits(
- reduction_numel_hint, numel_hint
- )
- @staticmethod
- def _unroll_reduction_fn(
- inner_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], OpsValue],
- reduction_ranges: Sequence[_IntLike],
- reduction_type: str,
- src_dtype: torch.dtype,
- ) -> Callable[[Sequence[_IntLike]], OpsValue]:
- """Convert inner_fn from a reduction to an pointwise"""
- reduction_ranges = V.graph.sizevars.guard_int_seq(reduction_ranges)
- combine_fn = get_reduction_combine_fn(reduction_type, src_dtype)
- def fn(index: Sequence[_IntLike]) -> Any:
- return functools.reduce(
- combine_fn,
- (
- value_fn(index, rindex)
- for rindex in itertools.product(
- *[range(x) for x in reduction_ranges]
- )
- ),
- )
- value_fn: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Any]
- if reduction_type in ("argmin", "argmax"):
- flatten_index = _fixed_indexer(
- reduction_ranges,
- FlexibleLayout.contiguous_strides(reduction_ranges),
- )
- def value_fn(
- index: Sequence[_IntLike], rindex: Sequence[_IntLike]
- ) -> tuple[OpsValue, OpsValue]:
- rindex = [sympy.expand(i) for i in rindex]
- return (
- inner_fn(index, rindex),
- ops.index_expr(flatten_index(rindex), torch.int64),
- )
- return lambda index: fn(index)[1]
- else:
- value_fn = inner_fn
- return fn
- @classmethod
- def create(
- cls,
- device: torch.device,
- dst_dtype: torch.dtype,
- src_dtype: torch.dtype,
- inner_fn: Callable[..., Any],
- ranges: Sequence[Expr],
- reduction_ranges: Sequence[Expr],
- reduction_type: ReductionType,
- reduction_hint: ReductionHint = ReductionHint.DEFAULT,
- input_node: Optional[IRNode] = None,
- ) -> Union[TensorBox, ShapeAsConstantBuffer]:
- reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
- if reduction_numel == 0:
- # N.B. This is a hack to generate the literal of the given type
- # Ideally, we should be fixing `def constant` in triton.py
- # but it breaks due to hardcoded dtypes in other places
- def py_cnst(val: object) -> Union[bool, float, int]:
- if dst_dtype == torch.bool:
- return bool(val)
- elif dst_dtype.is_floating_point:
- assert isinstance(val, SupportsFloat), type(val)
- return float(val)
- else:
- assert isinstance(val, SupportsInt), type(val)
- return int(val)
- rtypes_to_inits = {
- "sum": py_cnst(0),
- "xor_sum": py_cnst(0),
- "prod": py_cnst(1),
- "any": py_cnst(0),
- # "all" is desugared to `!any(!val)`
- }
- assert reduction_type in rtypes_to_inits.keys(), (
- f"{reduction_type} not supported for zero-dimension tensors!"
- )
- def const_fn(index: int) -> OpsValue:
- return ops.constant(rtypes_to_inits[reduction_type], dst_dtype)
- return Pointwise.create(
- device=device,
- dtype=src_dtype,
- inner_fn=const_fn,
- ranges=list(ranges),
- )
- if reduction_numel == 1:
- # this reduction is actually a pointwise op
- if reduction_type in ("argmin", "argmax"):
- def fn(index: int) -> OpsValue:
- return ops.constant(0, dst_dtype)
- else:
- def fn(index: int) -> OpsValue:
- reduction_index = [sympy.S.Zero for _ in reduction_ranges]
- return inner_fn(index, reduction_index)
- return Pointwise.create(
- device=device, dtype=dst_dtype, inner_fn=fn, ranges=ranges
- )
- if (
- isinstance(reduction_numel, Integer)
- and V.graph.sizevars.size_hint_or_throw(reduction_numel)
- < config.unroll_reductions_threshold
- and (sympy_product(ranges) != 1 or is_gpu(device.type))
- ):
- # NB: This works around https://github.com/pytorch/pytorch/issues/140457
- # since turning reductions into pointwise ops can exacerbate this problem
- return Pointwise.create(
- device=device,
- dtype=dst_dtype,
- inner_fn=cls._unroll_reduction_fn(
- inner_fn, reduction_ranges, reduction_type, src_dtype
- ),
- ranges=ranges,
- )
- # triton doesn't support reduce to single element well, so break it up
- hint, split = cls.num_splits(
- device,
- dst_dtype,
- src_dtype,
- inner_fn,
- ranges,
- reduction_ranges,
- reduction_type,
- reduction_numel,
- input_node,
- )
- def _maybe_increase_split(split: int) -> int:
- # don't apply min_num_split constraint for static shape case.
- if _is_static(reduction_numel):
- return split
- if split > 1:
- return max(split, config.min_num_split)
- else:
- return split
- split = _maybe_increase_split(split)
- # intermediate reduction in split can contain complex indexing,
- # and num_splits will fail to correctly set the hint
- # reuse the passed hint if available
- if reduction_hint == ReductionHint.DEFAULT:
- reduction_hint = hint
- if split == -1:
- assert input_node is not None
- with patch.object(FlexibleLayout, "allow_indexing", True):
- new_ranges, new_reduction_ranges = extract_input_node_reduction_ranges(
- input_node
- )
- assert new_ranges is not None
- assert new_reduction_ranges is not None
- return cls.create_multilayer_existing_ranges(
- device,
- dst_dtype,
- src_dtype,
- inner_fn,
- ranges,
- reduction_ranges,
- new_ranges,
- new_reduction_ranges,
- reduction_type,
- reduction_hint,
- )
- elif split > 1:
- # triton doesn't support reduce to single element well, so break it up
- return cls.create_multilayer(
- device,
- dst_dtype,
- src_dtype,
- inner_fn,
- ranges,
- reduction_ranges,
- reduction_type,
- split,
- reduction_hint,
- input_node,
- )
- return TensorBox.create(
- Reduction(
- device=device,
- dtype=dst_dtype,
- inner_fn=inner_fn,
- ranges=ranges,
- reduction_ranges=reduction_ranges,
- reduction_type=reduction_type,
- src_dtype=src_dtype,
- reduction_hint=reduction_hint,
- )
- )
- @staticmethod
- def default_accumulator(
- reduction_type: str, dtype: torch.dtype
- ) -> Union[_NumLike, Sequence[_NumLike]]:
- if reduction_type in ("max", "argmax"):
- if is_float_dtype(dtype):
- return float("-inf")
- elif is_boolean_dtype(dtype):
- return False
- else:
- return torch.iinfo(dtype).min
- if reduction_type in ("min", "argmin"):
- if is_float_dtype(dtype):
- return float("inf")
- elif is_boolean_dtype(dtype):
- return True
- else:
- return torch.iinfo(dtype).max
- zero = False if is_boolean_dtype(dtype) else 0
- one = True if is_boolean_dtype(dtype) else 1
- return {
- "sum": zero,
- "prod": one,
- "xor_sum": zero,
- "any": zero,
- "welford_reduce": (zero, zero, zero),
- "welford_combine": (zero, zero, zero),
- "online_softmax_reduce": (float("-inf"), zero),
- }[reduction_type]
- @staticmethod
- def default_value(
- reduction_type: str, dtype: torch.dtype
- ) -> Union[_NumLike, Sequence[_NumLike]]:
- if reduction_type == "welford_reduce":
- return 0
- return Reduction.default_accumulator(reduction_type, dtype)
- @staticmethod
- def _multilayer_second_step_hint(
- split: _IntLike, numel_hint: int, reduction_hint: ReductionHint
- ) -> ReductionHint:
- if split == -1:
- return reduction_hint
- if split <= 512 and numel_hint <= 512 and reduction_hint == ReductionHint.OUTER:
- return ReductionHint.OUTER_TINY
- if (
- split <= 1024
- and numel_hint <= 256
- and reduction_hint == ReductionHint.OUTER
- ):
- return ReductionHint.OUTER_TINY
- return reduction_hint
- @classmethod
- def check_for_split_dense_dim_reindexing(
- cls, reduction_numel: _IntLike, input_node: Optional[IRNode]
- ) -> Optional[int]:
- """
- If we are reducing over the full tensor, and it is non-dense in the last dimension,
- reindex so we reduce over the dense dimension. initially just handle complete
- reduction case
- """
- if input_node is None:
- return None
- if not V.graph.sizevars.statically_known_equals(
- input_node.get_numel(), reduction_numel
- ):
- return None
- input_node.realize()
- try:
- # finalize layout
- as_storage_and_layout(input_node)
- except NotImplementedError:
- return None
- strides = input_node.get_stride()
- for i, s in enumerate(strides[:-1]):
- if V.graph.sizevars.statically_known_equals(s, 1):
- return i
- return None
- @classmethod
- def _multilayer_wrap_loader(
- cls,
- loader: Callable[..., OpsValue],
- reduction_ranges: Sequence[_IntLike],
- reduction_numel: _IntLike,
- split: _IntLike,
- block_size: _IntLike,
- default: Union[_NumLike, Sequence[_NumLike]],
- input_node: Optional[IRNode] = None,
- ) -> Callable[..., object]:
- dense_index = cls.check_for_split_dense_dim_reindexing(
- reduction_numel, input_node
- )
- reindex = View.dynamic_reshape_indexer(
- reduction_ranges, [reduction_numel], dense_index
- )
- need_mask = not V.graph.sizevars.statically_known_true(
- sympy.Eq(reduction_numel % split, 0)
- )
- def wrapper_fn(
- index: Sequence[Symbol], reduction_index: Sequence[Symbol]
- ) -> OpsValue:
- (reduction_index,) = reduction_index
- *new_index, reduction_block = index
- indices = block_size * reduction_block + reduction_index
- def body() -> OpsValue:
- return loader(new_index, reindex([indices]))
- if need_mask:
- index_dtype = dtype_from_size(reduction_numel)
- mask = ops.lt(
- ops.index_expr(indices, index_dtype),
- ops.index_expr(reduction_numel, index_dtype),
- )
- return ops.masked(mask, body, default)
- else:
- return body()
- return wrapper_fn
- @classmethod
- def _multilayer_wrap_loader_existing_ranges(
- cls,
- loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
- original_ranges: Sequence[Expr],
- original_reduction_ranges: Sequence[Expr],
- new_ranges: Sequence[Integer],
- new_reduction_ranges: Sequence[Integer],
- ) -> Callable[[Sequence[sympy.Expr], Sequence[sympy.Expr]], OpsValue]:
- assert all(r == 1 for r in original_ranges), (
- f"Only enabled for numel_hint == 1, found {original_ranges=}"
- )
- reindex = View.dynamic_reshape_indexer(
- original_reduction_ranges, tuple(new_ranges) + tuple(new_reduction_ranges)
- )
- def wrapper_fn(
- merged_index: Sequence[Expr],
- new_reduction_index: Sequence[Expr],
- ) -> OpsValue:
- original_idx = merged_index[: len(original_ranges)]
- new_index = merged_index[len(original_ranges) :]
- return loader(
- original_idx,
- reindex(tuple(new_index) + tuple(new_reduction_index)),
- )
- return wrapper_fn
- @classmethod
- def create_multilayer_helper(
- cls,
- device: torch.device,
- dst_dtype: torch.dtype,
- src_dtype: torch.dtype,
- wrapper_fn: Callable[..., Any],
- original_ranges: Sequence[Expr],
- original_reduction_ranges: Sequence[Expr],
- new_ranges: list[Expr],
- new_reduction_ranges: list[Integer],
- reduction_type: ReductionType,
- split: _IntLike,
- reduction_hint: ReductionHint,
- ) -> Union[TensorBox, ShapeAsConstantBuffer]:
- """
- Break a large reduction up into multiple smaller reductions
- recursively
- """
- # triton will automatically compute reductions in fp32 if reducing over fp16/bf16
- # within the kernel. keep the intermediate in fp32 so as to keep the whole reduction
- # in fp32 and not reduce precision by breaking up the kernel into multiple layers
- intermediate_dtype = (
- dst_dtype
- if dst_dtype not in (torch.float16, torch.bfloat16)
- else torch.float
- )
- intermediate = Reduction.create(
- device,
- intermediate_dtype,
- src_dtype,
- wrapper_fn,
- new_ranges,
- new_reduction_ranges,
- reduction_type,
- reduction_hint,
- )
- intermediate.realize()
- intermediate_loader = intermediate.make_loader()
- def intermediate_fn(
- index: Sequence[_IntLike], reduction_index: Sequence[_IntLike]
- ) -> OpsValue:
- return intermediate_loader([*index, *reduction_index])
- numel_hint = V.graph.sizevars.size_hint(sympy_product(original_ranges))
- reduction_hint = cls._multilayer_second_step_hint(
- split, numel_hint, reduction_hint
- )
- assert original_ranges == new_ranges[: len(original_ranges)]
- return TensorBox.create(
- Reduction(
- device=device,
- dtype=dst_dtype,
- inner_fn=intermediate_fn,
- ranges=original_ranges,
- reduction_ranges=new_ranges[len(original_ranges) :],
- reduction_type=reduction_type,
- src_dtype=src_dtype,
- reduction_hint=reduction_hint,
- )
- )
- @classmethod
- def create_multilayer(
- cls,
- device: torch.device,
- dst_dtype: torch.dtype,
- src_dtype: torch.dtype,
- inner_fn: Callable[..., Any],
- ranges: Sequence[Expr],
- reduction_ranges: Sequence[Expr],
- reduction_type: ReductionType,
- split: _IntLike,
- reduction_hint: ReductionHint,
- input_node: Optional[IRNode] = None,
- ) -> Union[TensorBox, ShapeAsConstantBuffer]:
- """
- Break a large reduction up into multiple smaller reductions
- recursively
- """
- # TODO(jansel): realize the reduction so we can do dynamic indexing
- reduction_numel = sympy_product(reduction_ranges)
- block_size = FloorDiv(reduction_numel + (split - 1), split)
- default = cls.default_value(reduction_type, dst_dtype)
- wrapper_fn = cls._multilayer_wrap_loader(
- inner_fn,
- reduction_ranges,
- reduction_numel,
- split,
- block_size,
- default,
- input_node,
- )
- return cls.create_multilayer_helper(
- device,
- dst_dtype,
- src_dtype,
- wrapper_fn,
- ranges,
- reduction_ranges,
- [*ranges, split],
- [block_size],
- reduction_type,
- split,
- reduction_hint,
- )
- @classmethod
- def create_multilayer_existing_ranges(
- cls,
- device: torch.device,
- dst_dtype: torch.dtype,
- src_dtype: torch.dtype,
- inner_fn: Callable[..., Any],
- original_ranges: Sequence[Expr],
- original_reduction_ranges: Sequence[Expr],
- new_ranges: list[Integer],
- new_reduction_ranges: list[Integer],
- reduction_type: ReductionType,
- reduction_hint: ReductionHint,
- ) -> Union[TensorBox, ShapeAsConstantBuffer]:
- """
- Break a large reduction up into multiple smaller reductions
- recursively
- """
- wrapper_fn = cls._multilayer_wrap_loader_existing_ranges(
- inner_fn,
- original_ranges,
- original_reduction_ranges,
- new_ranges,
- new_reduction_ranges,
- )
- return cls.create_multilayer_helper(
- device,
- dst_dtype,
- src_dtype,
- wrapper_fn,
- original_ranges,
- original_reduction_ranges,
- [*original_ranges, *new_ranges],
- new_reduction_ranges,
- reduction_type,
- -1,
- reduction_hint,
- )
- def _fixed_indexer(
- size: Sequence[int],
- stride: Optional[Sequence[int]] = None,
- offset: Expr = Integer(0),
- ) -> Callable[[Sequence[Expr]], Expr]:
- """A closure containing math to read a given element"""
- def indexer(index: Sequence[int]) -> int:
- assert stride is not None and len(index) == len(stride)
- assert len(index) == len(size)
- result = offset
- for idx, st, sz in zip(index, stride, size):
- if sz != 1:
- result = result + idx * st
- return result
- return indexer
- INNER_FN_TY: TypeAlias = Callable[[Sequence[Expr], Sequence[Expr]], OpsValue]
- class MultiOutputReduction(Reduction):
- output_index: int
- def __init__(
- self,
- device: torch.device,
- dst_dtype: torch.dtype,
- inner_fns: Union[INNER_FN_TY, Sequence[INNER_FN_TY]],
- ranges: Sequence[Integer],
- reduction_ranges: Sequence[Integer],
- reduction_type: ReductionType,
- src_dtype: torch.dtype,
- reduction_hint: ReductionHint,
- output_index: int,
- ):
- if callable(inner_fns):
- inner_fns = (inner_fns,)
- loader: Callable[[Sequence[Expr], Sequence[Expr]], Any]
- if len(inner_fns) == 1:
- loader = inner_fns[0]
- else:
- def loader(
- idx: Sequence[Expr], reduction_idx: Sequence[Expr]
- ) -> tuple[OpsValue, ...]:
- return tuple(fn(idx, reduction_idx) for fn in inner_fns)
- super().__init__(
- device=device,
- dtype=dst_dtype,
- inner_fn=loader,
- ranges=ranges,
- reduction_ranges=reduction_ranges,
- reduction_type=reduction_type,
- src_dtype=src_dtype,
- reduction_hint=reduction_hint,
- )
- self.output_index = output_index
- def store_reduction(
- self,
- output_name: Optional[str],
- indexer: Callable[[Sequence[Expr]], Never],
- vars: Sequence[Expr],
- reduction_vars: Sequence[Symbol],
- ) -> Any:
- values = ops.reduction(
- self.dtype,
- self.src_dtype,
- self.reduction_type,
- self.inner_fn(vars, reduction_vars),
- )
- assert isinstance(values, (tuple, list)), type(values)
- value = values[self.output_index]
- return ops.store_reduction(output_name or "unnamed", indexer(vars), value)
- class OnlineSoftmaxReduction(MultiOutputReduction):
- @classmethod
- def create( # type: ignore[override]
- cls,
- device: torch.device,
- dst_dtype: torch.dtype,
- src_dtype: torch.dtype,
- inner_fn: Callable[..., Any],
- ranges: Sequence[Expr],
- reduction_ranges: Sequence[Expr],
- num_output: int,
- reduction_hint: ReductionHint = ReductionHint.DEFAULT,
- input_node: Optional[IRNode] = None,
- ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]:
- """
- Create the reduction disregarding splitting.
- """
- results = tuple(
- TensorBox.create(
- MultiOutputReduction(
- device,
- dst_dtype,
- inner_fn,
- ranges,
- reduction_ranges,
- "online_softmax_reduce",
- src_dtype,
- reduction_hint,
- output_idx,
- )
- )
- for output_idx in range(num_output)
- )
- for t in results:
- t.realize()
- return results
- class WelfordReduction(MultiOutputReduction):
- @classmethod
- def create( # type: ignore[override]
- cls,
- device: torch.device,
- dtype: torch.dtype,
- inner_fns: Sequence[Callable[..., Any]],
- ranges: list[Integer],
- reduction_ranges: list[Integer],
- reduction_type: ReductionType,
- reduction_hint: ReductionHint = ReductionHint.DEFAULT,
- ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]:
- assert reduction_type in ("welford_reduce", "welford_combine")
- reduction_numel = V.graph.sizevars.simplify(sympy_product(reduction_ranges))
- def const(val: int) -> Union[TensorBox, ShapeAsConstantBuffer]:
- def inner_fn(idx: Sequence[Expr]) -> OpsValue:
- return ops.constant(
- val,
- dtype,
- )
- return Pointwise.create(
- device=device,
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=list(ranges),
- )
- if reduction_numel == 0:
- mean = const(0)
- m2 = const(0)
- weight = const(0)
- return mean, m2, weight
- if reduction_numel == 1:
- def copy(
- loader: Callable[[Sequence[Expr], Sequence[Expr]], OpsValue],
- ) -> Union[TensorBox, ShapeAsConstantBuffer]:
- def inner_fn(idx: Sequence[Expr]) -> OpsValue:
- reduction_index = [sympy.S.Zero for _ in reduction_ranges]
- return loader(idx, reduction_index)
- return Pointwise.create(
- device=device,
- dtype=dtype,
- inner_fn=inner_fn,
- ranges=list(ranges),
- )
- if reduction_type == "welford_reduce":
- return copy(inner_fns[0]), const(0), const(1)
- else:
- return tuple(copy(fn) for fn in inner_fns)
- # TODO: Unrolled reduction
- # if (
- # isinstance(reduction_numel, Integer)
- # and V.graph.sizevars.size_hint(reduction_numel)
- # < config.unroll_reductions_threshold
- # and sympy_product(ranges) != 1
- # ):
- # return Pointwise.create(
- # device,
- # dst_dtype,
- # cls._unroll_reduction_fn(
- # inner_fn, reduction_ranges, reduction_type, src_dtype,
- # ),
- # ranges,
- # )
- # triton doesn't support reduce to single element well, so break it up
- hint, split = Reduction.num_splits(
- device,
- dtype,
- dtype,
- inner_fns[0],
- ranges,
- reduction_ranges,
- reduction_type=reduction_type,
- reduction_numel=reduction_numel,
- )
- # intermediate reduction in split can contain complex indexing,
- # and num_splits will fail to correctly set the hint
- # reuse the passed hint if available
- if reduction_hint == ReductionHint.DEFAULT:
- reduction_hint = hint
- if split > 1:
- # triton doesn't support reduce to single element well, so break it up
- return cls.create_multilayer(
- device,
- dtype,
- inner_fns,
- ranges,
- reduction_ranges,
- reduction_type,
- split,
- reduction_hint,
- )
- results = [
- TensorBox.create(
- WelfordReduction(
- device,
- dtype,
- inner_fns,
- ranges,
- reduction_ranges,
- reduction_type,
- dtype,
- reduction_hint,
- output_idx,
- )
- )
- for output_idx in range(3)
- ]
- for t in results:
- t.realize()
- return results
- @staticmethod
- def default_value(
- reduction_type: str, dtype: torch.dtype
- ) -> Union[_NumLike, Sequence[_NumLike]]:
- return (0, 0, 0)
- @classmethod
- def create_multilayer( # type: ignore[override]
- cls,
- device: torch.device,
- dtype: torch.dtype,
- inner_fns: Sequence[Callable[..., Any]],
- ranges: list[Integer],
- reduction_ranges: list[Integer],
- reduction_type: ReductionType,
- split: _IntLike,
- reduction_hint: ReductionHint,
- ) -> Sequence[Union[TensorBox, ShapeAsConstantBuffer]]:
- """
- Break a large reduction up into multiple smaller reductions
- recursively
- """
- reduction_numel = sympy_product(reduction_ranges)
- need_mask = not V.graph.sizevars.statically_known_true(
- sympy.Eq(reduction_numel % split, 0)
- )
- if need_mask and reduction_type != "welford_combine":
- # If we need mask, then "welford_reduce" doesn't work because
- # masked inputs shouldn't count towards the welford weight
- def constant(
- idx: Sequence[Expr], reduction_idx: Sequence[Expr], value: int
- ) -> OpsValue:
- return ops.constant(value, dtype)
- return cls.create_multilayer(
- device=device,
- dtype=dtype,
- inner_fns=(
- inner_fns[0],
- partial(constant, value=0),
- partial(constant, value=1),
- ),
- ranges=ranges,
- reduction_ranges=reduction_ranges,
- reduction_type="welford_combine",
- split=split,
- reduction_hint=reduction_hint,
- )
- block_size = FloorDiv(reduction_numel + (split - 1), split)
- intermediates = WelfordReduction.create(
- device,
- dtype,
- tuple(
- cls._multilayer_wrap_loader(
- loader,
- reduction_ranges,
- reduction_numel,
- split,
- block_size,
- default=0,
- )
- for loader in inner_fns
- ),
- [*ranges, split],
- [block_size],
- reduction_type,
- reduction_hint,
- )
- for i in intermediates:
- i.realize()
- def intermediate_loader_fn(
- index: Sequence[Expr],
- reduction_index: Sequence[Expr],
- loader: Callable[[Sequence[Expr]], OpsValue],
- ) -> OpsValue:
- return loader([*index, *reduction_index])
- numel_hint = V.graph.sizevars.size_hint(sympy_product(ranges))
- reduction_hint = cls._multilayer_second_step_hint(
- split, numel_hint, reduction_hint
- )
- return WelfordReduction.create(
- device,
- dtype,
- tuple(
- partial(intermediate_loader_fn, loader=i.make_loader())
- for i in intermediates
- ),
- ranges,
- [split],
- # welford_reduce turns one input into three outputs, which are combined with welford_combine
- "welford_combine",
- reduction_hint,
- )
- @ir_dataclass
- class Scan(Loops):
- scan_ranges: list[Integer]
- size: list[Integer]
- combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]]
- reindex: Callable[[Sequence[_IntLike], Sequence[_IntLike]], Sequence[_IntLike]]
- reduction_hint: ReductionHint
- output_index: int
- # output_index indexes the following tuples
- dtypes: tuple[torch.dtype, ...]
- inner_fns: tuple[Callable[..., Any], ...]
- # HACK we mimic reduction
- @cache_on_self_and_args("Scan")
- def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
- # TODO: Can combine_fn/reindex close over unbacked symbols? If so, we
- # need to explicitly represent the closure so we can pull out unbacked
- # symbols here
- return (
- super().get_free_symbol_uses(unbacked_only)
- | OrderedSet().union(
- *(get_free_symbols(e, unbacked_only) for e in self.scan_ranges)
- )
- | OrderedSet().union(
- *(get_free_symbols(e, unbacked_only) for e in self.size)
- )
- )
- def __post_init__(self) -> None:
- assert len(self.ranges) + len(self.scan_ranges) == len(self.size)
- super().__post_init__()
- def store_reduction(
- self,
- output_name: Optional[str],
- indexer: Callable[[Sequence[_IntLike]], Never],
- vars: Sequence[Expr],
- scan_vars: Sequence[Symbol],
- ) -> Any:
- idx = self.reindex(vars, scan_vars)
- values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
- result = ops.scan(self.dtypes, self.combine_fn, values)
- return ops.store(
- output_name or "unnamed", indexer(idx), result[self.output_index]
- )
- def get_reduction_type(self) -> Optional[str]:
- # return self.scan_op
- return "custom"
- def get_reduction_size(self) -> Sequence[Expr]:
- return self.scan_ranges
- def get_size(self) -> Sequence[Expr]:
- return self.size
- def get_pointwise_size(self) -> Sequence[Expr]:
- return self.ranges
- def index_length(self) -> int:
- return len(self.ranges) + len(self.scan_ranges)
- def inner_fn_args(self) -> Sequence[Sequence[_IntLike]]:
- index = self._index(self.ranges)
- rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
- idx = self.reindex(index, rindex)
- return (idx,)
- def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
- index = self._index(self.ranges)
- rindex = self._index(self.scan_ranges, SymT.R0_INDEX)
- idx = self.reindex(index, rindex)
- return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only)
- @classmethod
- def create( # type: ignore[override]
- cls,
- device: torch.device,
- dtypes: tuple[torch.dtype, ...],
- inner_fns: tuple[Callable[[Sequence[Expr]], Any], ...],
- size: list[Integer],
- axis: int,
- combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
- reduction_hint: ReductionHint = ReductionHint.DEFAULT,
- *,
- # Whether we have the option to fallback to aten
- can_fallback_to_aten: bool = True,
- **kwargs: Any,
- ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]:
- pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
- scan_ranges = [size[axis]]
- if not V.graph.has_feature(device, BackendFeature.SCAN):
- return [None] * len(dtypes)
- if len(dtypes) > 1 and not V.graph.has_feature(
- device, BackendFeature.TUPLE_REDUCTION
- ):
- return [None] * len(dtypes)
- sizevars = V.graph.sizevars
- scan_numel = sizevars.simplify(sympy_product(scan_ranges))
- assert len(dtypes) == len(inner_fns)
- # Scan with a single element is just a copy
- if sizevars.statically_known_true(sympy.Le(scan_numel, 1)):
- return [
- Pointwise.create(
- device=device,
- dtype=dtypes[output_index],
- inner_fn=inner_fns[output_index],
- ranges=size,
- )
- for output_index in range(len(dtypes))
- ]
- reduction_hint, num_splits = cls.num_splits(
- device=device,
- dtype=dtypes[0],
- inner_fn=inner_fns[0],
- axis=axis,
- pointwise_ranges=pointwise_ranges,
- scan_ranges=scan_ranges,
- combine_fn=combine_fn,
- scan_numel=scan_numel,
- )
- scan_type = Scan
- if num_splits > 1:
- supports_split = (
- torch.version.hip is None or (has_triton and triton_version >= "3.3.0")
- ) and (len(dtypes) == 1)
- if not supports_split:
- if can_fallback_to_aten:
- # Fallback to ATen
- return [None] * len(dtypes)
- else:
- num_splits = 1
- else:
- scan_type = SplitScan
- def reindex(index: Sequence[Expr], scan_index: Sequence[Expr]) -> list[Expr]:
- assert len(scan_index) == len(scan_ranges)
- assert len(index) == len(pointwise_ranges)
- return [*index[:axis], *scan_index, *index[axis:]]
- results = [
- TensorBox.create(
- scan_type(
- device=device,
- dtype=dtypes[output_index],
- dtypes=dtypes,
- inner_fn=inner_fns[output_index],
- inner_fns=inner_fns,
- size=size,
- ranges=pointwise_ranges,
- scan_ranges=scan_ranges,
- combine_fn=combine_fn,
- reindex=reindex,
- reduction_hint=reduction_hint,
- output_index=output_index,
- **kwargs,
- )
- )
- for output_index in range(len(dtypes))
- ]
- for result in results:
- result.realize()
- return results
- @classmethod
- def num_splits(
- cls,
- device: torch.device,
- dtype: torch.dtype,
- inner_fn: Callable[[Sequence[Expr]], OpsValue],
- axis: int,
- pointwise_ranges: list[Integer],
- scan_ranges: list[Integer],
- combine_fn: Callable[[tuple[Any, ...], tuple[Any, ...]], tuple[Any, ...]],
- scan_numel: Expr,
- ) -> tuple[ReductionHint, _IntLike]:
- # TODO: custom splitting heuristic for scan
- def wrapper_fn(idx: Sequence[Expr], reduction_idx: Sequence[Expr]) -> OpsValue:
- return inner_fn([*idx[:axis], *reduction_idx, *idx[axis:]])
- return Reduction.num_splits(
- device=device,
- dst_dtype=dtype,
- src_dtype=dtype,
- inner_fn=wrapper_fn,
- ranges=pointwise_ranges,
- reduction_ranges=scan_ranges,
- reduction_type="scan",
- reduction_numel=scan_numel,
- )
- # This signifies a scan op that should go through TritonSplitScanKernel codegen on CUDA.
- @ir_dataclass
- class SplitScan(Scan):
- pass
- @ir_dataclass
- class Sort(Loops):
- # Sorts a tuple of key, value pairs
- sort_ranges: list[Integer]
- size: list[Integer]
- reindex: Callable[[Sequence[Expr], Sequence[Expr]], Sequence[Expr]]
- reduction_hint: ReductionHint
- output_index: int
- # output_index indexes the following tuples
- dtypes: tuple[torch.dtype, ...]
- inner_fns: tuple[Callable[..., Any], ...]
- stable: bool
- descending: bool
- # HACK we mimic reduction
- @cache_on_self_and_args("Sort")
- def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
- return (
- super().get_free_symbol_uses(unbacked_only)
- | OrderedSet().union(
- *(get_free_symbols(e, unbacked_only) for e in self.sort_ranges)
- )
- | OrderedSet().union(
- *(get_free_symbols(e, unbacked_only) for e in self.size)
- )
- )
- def __post_init__(self) -> None:
- assert len(self.ranges) + len(self.sort_ranges) == len(self.size)
- super().__post_init__()
- def store_reduction(
- self,
- output_name: Optional[str],
- indexer: Callable[[Sequence[Expr]], Expr],
- vars: Sequence[Expr],
- reduction_vars: Sequence[Expr],
- ) -> Any:
- idx = self.reindex(vars, reduction_vars)
- values = tuple(inner_fn(idx) for inner_fn in self.inner_fns)
- result = ops.sort(self.dtypes, values, self.stable, self.descending)
- return ops.store(
- output_name or "unnamed", indexer(idx), result[self.output_index]
- )
- def get_reduction_type(self) -> Optional[str]:
- return "sort"
- def get_reduction_size(self) -> Sequence[Expr]:
- return self.sort_ranges
- def get_size(self) -> Sequence[Expr]:
- return self.size
- def get_pointwise_size(self) -> Sequence[Expr]:
- return self.ranges
- def index_length(self) -> int:
- return len(self.ranges) + len(self.sort_ranges)
- def inner_fn_args(self) -> Sequence[Sequence[Expr]]:
- index = self._index(self.ranges)
- rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
- idx = self.reindex(index, rindex)
- return (idx,)
- def inner_fn_free_symbols(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
- index = self._index(self.ranges)
- rindex = self._index(self.sort_ranges, SymT.R0_INDEX)
- idx = self.reindex(index, rindex)
- return extract_free_symbols(self.inner_fn, idx, unbacked_only=unbacked_only)
- @classmethod
- def create( # type: ignore[override]
- cls,
- device: torch.device,
- dtypes: tuple[torch.dtype, ...],
- inner_fns: tuple[Callable[[list[Expr]], Any], ...],
- size: list[Integer],
- axis: int,
- stable: bool,
- descending: bool,
- reduction_hint: ReductionHint = ReductionHint.DEFAULT,
- **kwargs: Any,
- ) -> Sequence[Optional[Union[TensorBox, ShapeAsConstantBuffer]]]:
- pointwise_ranges = [*size[:axis], *size[axis + 1 :]]
- sort_ranges = [size[axis]]
- if not V.graph.has_feature(device, BackendFeature.SORT):
- return [None] * len(dtypes)
- sizevars = V.graph.sizevars
- sort_numel = sizevars.simplify(sympy_product(sort_ranges))
- # Heuristic, smallest rblock where triton usually outperforms aten.sort
- # It also isn't bandwidth bound so fusion is unlikely to help.
- max_rblock = 512
- is_persistent_kernel = (
- config.triton.persistent_reductions
- and sizevars.statically_known_true(sympy.Le(sort_numel, max_rblock))
- )
- if not is_persistent_kernel:
- # We only support persistent triton kernels
- return [None] * len(dtypes)
- assert len(dtypes) == len(inner_fns)
- # Sort with a single element is just a copy
- if sizevars.statically_known_true(sympy.Le(sort_numel, 1)):
- return [
- Pointwise.create(
- device=device,
- dtype=dtypes[output_index],
- inner_fn=inner_fns[output_index],
- ranges=size,
- )
- for output_index in range(len(dtypes))
- ]
- def reindex(index: Sequence[Expr], sort_index: Sequence[Expr]) -> list[Expr]:
- assert len(sort_index) == len(sort_ranges)
- assert len(index) == len(pointwise_ranges)
- return [*index[:axis], *sort_index, *index[axis:]]
- results = [
- TensorBox.create(
- Sort(
- device=device,
- dtype=dtypes[output_index],
- dtypes=dtypes,
- inner_fn=inner_fns[output_index],
- inner_fns=inner_fns,
- size=size,
- ranges=pointwise_ranges,
- sort_ranges=sort_ranges,
- reindex=reindex,
- reduction_hint=reduction_hint,
- output_index=output_index,
- stable=stable,
- descending=descending,
- **kwargs,
- )
- )
- for output_index in range(len(dtypes))
- ]
- for result in results:
- result.realize()
- return results
- def is_storage_and_layout(x: IRNode) -> bool:
- try:
- as_storage_and_layout(x, freeze=False)
- return True
- except NotImplementedError:
- return False
- def is_contiguous_storage_and_layout(x: IRNode) -> bool:
- try:
- _buffer, layout = as_storage_and_layout(x, freeze=False)
- # pad the stride here so we will NOT claim an tensor as contiguous
- # if a padding is gonna happen.
- if layout.should_pad_strides():
- layout.pad_strides()
- return layout.is_contiguous()
- except NotImplementedError:
- return False
- def as_storage_and_layout(
- x: IRNode,
- freeze: bool = True,
- want_contiguous: bool = False,
- stride_order: Optional[Sequence[Union[int, Integer]]] = None,
- allow_padding: bool = False,
- exact_strides: Optional[Sequence[Union[int, Integer]]] = None,
- ) -> tuple[StorageBox, Layout]:
- """
- Try to simplify x into a StorageBox and a Layout.
- allow_padding only affect how we apply stride_order. When allow_padding
- is True, we have the freedom to add padding when applying the stride_order.
- """
- if isinstance(x, TensorBox):
- return as_storage_and_layout(
- x.data,
- freeze=freeze,
- want_contiguous=want_contiguous,
- stride_order=stride_order,
- allow_padding=allow_padding,
- exact_strides=exact_strides,
- )
- if isinstance(x, StorageBox):
- _, layout = as_storage_and_layout(
- x.data,
- freeze=freeze,
- want_contiguous=want_contiguous,
- stride_order=stride_order,
- allow_padding=allow_padding,
- exact_strides=exact_strides,
- )
- return x, x.data.get_layout()
- if isinstance(x, Buffer):
- if freeze:
- if want_contiguous:
- x.freeze_layout()
- assert x.get_layout().is_contiguous()
- elif stride_order is not None:
- x.freeze_layout_with_stride_order(
- stride_order, allow_padding=allow_padding
- )
- elif exact_strides is not None:
- x.freeze_layout_with_exact_strides(
- exact_strides, allow_padding=allow_padding
- )
- else:
- x.decide_layout()
- return StorageBox(x), x.get_layout()
- if isinstance(x, ReinterpretView):
- # making the base of x contiguous or stride_ordered will not necessarily make
- # the ReinterpretView either, so don't pass along those arguments
- buffer, _ = as_storage_and_layout(
- x.data,
- freeze=freeze,
- )
- return buffer, x.layout
- raise NotImplementedError
- def is_stride_order_storage_and_layout(
- x: IRNode, stride_order: Sequence[Union[int, Integer]]
- ) -> bool:
- try:
- _buffer, layout = as_storage_and_layout(x, freeze=False)
- return layout.is_stride_ordered(stride_order)
- except NotImplementedError:
- return False
- def is_unaligned(node: IRNode) -> bool:
- if isinstance(node, (TensorBox, StorageBox)):
- return is_unaligned(node.data)
- if isinstance(node, ReinterpretView):
- layout = node.layout
- has_unaligned_layout = not V.graph.sizevars.statically_known_multiple_of(
- layout.offset * get_dtype_size(layout.dtype), GPU_ALIGN_BYTES
- )
- return is_unaligned(node.data) or has_unaligned_layout
- if isinstance(node, Buffer):
- return node.get_name() in V.graph.unaligned_buffers
- # assume to be aligned otherwise
- return False
- @ir_dataclass
- class BaseView(IRNode):
- data: IRNode
- @cache_on_self_and_args("BaseView")
- def get_free_symbol_uses(self, unbacked_only: bool = False) -> OrderedSet[Symbol]:
- return self.data.get_free_symbol_uses(unbacked_only)
- def make_reindexer(self) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
- raise NotImplementedError(f"make_reindexer NYI on {self}")
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- inner = self.data.make_indexer()
- reindex = self.make_reindexer()
- def indexer(idx: Sequence[Expr]) -> Expr:
- return inner(reindex(idx))
- return indexer
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- inner = self.data.make_loader()
- reindex = self.make_reindexer()
- def loader(idx: Sequence[Expr]) -> OpsValue:
- return inner(reindex(idx))
- return loader
- @property
- def dtype(self) -> torch.dtype:
- return self.data.get_dtype()
- def get_layout(self) -> Layout:
- return self.data.get_layout()
- def get_device(self) -> Optional[torch.device]:
- return self.data.get_device()
- def get_origin_node(self) -> Optional[torch.fx.Node]:
- return None
- def get_name(self) -> str:
- return self.data.get_name()
- def get_pointwise_size(self) -> Sequence[Expr]:
- return self.get_size()
- def mark_reuse(self, users: int) -> None:
- return self.data.mark_reuse(users)
- def has_exceeded_max_reads(self) -> bool:
- return self.data.has_exceeded_max_reads()
- def realize(self) -> Optional[str]:
- return self.data.realize()
- def realize_hint(self) -> None:
- self.data.realize_hint()
- def get_storage_numel(self) -> _IntLike:
- return self.data.get_storage_numel()
- def is_extern(self) -> bool:
- return self.data.is_extern()
- def is_module_buffer(self) -> bool:
- assert isinstance(self.data, BaseView), type(self.data)
- return self.data.is_module_buffer()
- def get_read_names(self) -> OrderedSet[str]:
- return self.data.get_read_names()
- def get_reads(self) -> OrderedSet[Dep]:
- with patch.object(FlexibleLayout, "allow_indexing", True):
- return extract_read_writes(
- self.make_loader(),
- self.get_size(),
- ).reads
- def unwrap_view(self) -> IRNode:
- x: IRNode = self
- while isinstance(x, BaseView):
- x = x.data
- return x
- def constant_to_device(self, device: torch.device) -> IRNode:
- """Move this to a given device. Requires that all reads are to constants."""
- loader = self.make_loader()
- loader = patch.object(ConstantBuffer, "override_device", device)(loader)
- return Pointwise(
- device=device,
- dtype=self.get_dtype(),
- inner_fn=loader,
- ranges=self.get_size(),
- )
- @ir_dataclass
- class ExpandView(BaseView):
- size: Sequence[Expr]
- @staticmethod
- def _normalize_size(x: IRNode, new_size: Sequence[_IntLike]) -> Sequence[_IntLike]:
- """Replace `-1` with correct sizes"""
- sizevars = V.graph.sizevars
- new_size = [sympy.expand(s) for s in new_size]
- old_size = x.get_size()
- old_size = [None] * (len(new_size) - len(old_size)) + list(old_size)
- assert len(new_size) == len(old_size)
- for i in range(len(new_size)):
- if new_size[i] == -1:
- assert old_size[i] is not None
- new_size[i] = old_size[i]
- elif old_size[i] is None or V.graph.sizevars.is_size_one_or_false(
- old_size[i]
- ):
- pass
- else:
- # Sanity check: Expect broadcast compatibility
- #
- # NB: new_size[i] == old_size[i] is expected to already be
- # guarded because the meta formula was expected to have taught
- # us this equality.
- assert sizevars.size_hint(new_size[i] - old_size[i], fallback=0) == 0, (
- "Broadcast failed in ExpandView({x.get_size()}, {new_size}) on dimension {i}"
- )
- return new_size
- @classmethod
- def create(cls, x: IRNode, new_size: Sequence[_IntLike]) -> BaseView:
- new_size = cls._normalize_size(x, new_size)
- if is_storage_and_layout(x):
- storage, old_layout = as_storage_and_layout(x)
- skip = len(new_size) - len(old_layout.size)
- assert skip >= 0
- new_stride = [sympy.S.Zero] * skip
- for stride, size in zip(old_layout.stride, old_layout.size):
- new_stride.append(
- stride
- if not V.graph.sizevars.is_size_one_or_false(size)
- else sympy.S.Zero
- )
- new_layout = FixedLayout(
- old_layout.device,
- old_layout.dtype,
- list(new_size),
- new_stride,
- old_layout.offset,
- old_layout.is_pinned,
- )
- return ReinterpretView(data=storage, layout=new_layout)
- return ExpandView(data=x, size=new_size)
- def get_size(self) -> Sequence[Expr]:
- return self.size
- def make_reindexer(
- self,
- ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
- target = self.get_size()
- actual = self.data.get_size()
- skip = len(target) - len(actual)
- def reindex(
- index: Sequence[Expr],
- ) -> Sequence[Expr]:
- index = list(index[skip:])
- assert len(index) == len(actual)
- for i in range(len(actual)):
- if actual[i] == 1:
- # zero out broadcast dimension
- index[i] = sympy.S.Zero
- return index
- return reindex
- @ir_dataclass
- class PermuteView(BaseView):
- dims: list[Expr]
- @classmethod
- def create(cls, x: IRNode, dims: Sequence[int]) -> BaseView:
- dims = cls._map_neg_dims(dims)
- assert OrderedSet(dims) == OrderedSet(range(len(dims)))
- if is_storage_and_layout(x):
- storage, old_layout = as_storage_and_layout(x)
- new_layout = FixedLayout(
- old_layout.device,
- old_layout.dtype,
- [old_layout.size[i] for i in dims],
- [old_layout.stride[i] for i in dims],
- old_layout.offset,
- old_layout.is_pinned,
- )
- return ReinterpretView(data=storage, layout=new_layout)
- return PermuteView(data=x, dims=dims)
- @classmethod
- def _map_neg_dims(cls, dims: Sequence[int]) -> list[int]:
- return [dim if dim >= 0 else len(dims) + dim for dim in dims]
- def get_size(self) -> Sequence[Expr]:
- assert OrderedSet(self._map_neg_dims(self.dims)) == OrderedSet(
- range(len(self.dims))
- )
- size = self.data.get_size()
- return [size[i] for i in self.dims]
- def make_reindexer(
- self,
- ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
- inv = {j: i for i, j in enumerate(self.dims)}
- inv = [inv[i] for i in range(len(self.dims))]
- assert OrderedSet(inv) == OrderedSet(range(len(self.dims)))
- def reindex(
- index: Sequence[Expr],
- ) -> Sequence[Expr]:
- return [index[i] for i in inv]
- return reindex
- @ir_dataclass
- class SqueezeView(BaseView):
- @classmethod
- def create(cls, x: IRNode, *, dim: Optional[int] = None) -> IRNode:
- if is_storage_and_layout(x):
- storage, old_layout = as_storage_and_layout(x)
- new_size = []
- new_stride = []
- if dim is not None:
- assert isinstance(dim, int), type(dim)
- assert 0 <= dim and dim < len(old_layout.size)
- for i, (size, stride) in enumerate(zip(old_layout.size, old_layout.stride)):
- if dim is None:
- if size != 1:
- new_size.append(size)
- new_stride.append(stride)
- else:
- if i != dim:
- new_size.append(size)
- new_stride.append(stride)
- else:
- assert size == 1, "expected squeezed size to be 1"
- new_layout = FixedLayout(
- old_layout.device,
- old_layout.dtype,
- new_size,
- new_stride,
- old_layout.offset,
- old_layout.is_pinned,
- )
- return ReinterpretView(data=storage, layout=new_layout)
- if dim is None:
- # redirect to a generic view
- return View.create(x, [s for s in x.get_size() if s != 1])
- else:
- assert x.get_size()[dim] == 1
- return View.create(x, [s for i, s in enumerate(x.get_size()) if i != dim])
- @staticmethod
- def squeezer(
- size: Sequence[Expr],
- ) -> tuple[list[int], Callable[[Sequence[Expr]], tuple[Expr]]]:
- new_size = [s for s in size if s != 1]
- not_one = [i for i, s in enumerate(size) if s != 1]
- length = len(size)
- def reindex(index: Sequence[Expr]) -> tuple[Expr]:
- assert len(index) == len(not_one), f"{index} {not_one}"
- new_index = [sympy.S.Zero] * length
- for idx, s in zip(not_one, index):
- new_index[idx] = s
- return tuple(new_index)
- return new_size, reindex
- def __init__(self, data: Any) -> None:
- raise AssertionError("use SqueezeView.create()")
- @ir_dataclass
- class GenericView(BaseView):
- size: Sequence[Expr]
- reindex: Callable[[Sequence[Expr]], Sequence[Expr]]
- def make_reindexer(
- self,
- ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
- return self.reindex
- def reindex_str(self) -> str:
- index_old = [
- sympy_index_symbol_with_prefix(SymT.INDEX, n) for n in range(len(self.size))
- ]
- index_new = list(self.reindex(index_old))
- return f"lambda {', '.join(map(str, index_old))}: {index_new}"
- def __str__(self) -> str:
- return self.str_helper(
- [self.data, f"size={self.size}", f"reindex={self.reindex_str()}"]
- )
- __repr__ = __str__
- @classmethod
- def create(
- cls,
- x: IRNode,
- new_size: Sequence[Expr],
- reindex: Callable[[Sequence[Expr]], Sequence[Expr]],
- ) -> BaseView:
- return cls(data=x, size=list(new_size), reindex=reindex)
- def get_size(self) -> Sequence[Expr]:
- return self.size
- @ir_dataclass
- class View(GenericView):
- @staticmethod
- def handle_negative_index(idx: Expr, size: Expr) -> Expr:
- idx = sympy.expand(idx)
- size = sympy.expand(size)
- evaluate_expr = V.graph.sizevars.shape_env.evaluate_expr
- if evaluate_expr(sympy.Lt(idx, 0)):
- idx = idx + size
- return idx
- @classmethod
- def create(cls, x: IRNode, new_size: Sequence[Expr]) -> IRNode: # type: ignore[override]
- assert isinstance(new_size, Sequence), type(new_size)
- old_size, new_size = cls.resolve_negative_size(x.get_size(), new_size)
- # Skip pointless views
- if V.graph.sizevars.statically_known_list_equals(old_size, new_size):
- return x
- unbacked_symbols_in_sizes = False
- if (
- len(free_unbacked_symbols(old_size)) > 0
- or len(free_unbacked_symbols(new_size)) > 0
- ):
- unbacked_symbols_in_sizes = True
- if 0 in new_size:
- def fake_reindex(index: Any) -> tuple[int, ...]:
- return tuple([0] * len(old_size))
- return cls(data=x, size=list(new_size), reindex=fake_reindex)
- # TODO: a new class for FixedTransferLayout that output layout is constrained by input layout
- elif is_contiguous_storage_and_layout(x) or unbacked_symbols_in_sizes:
- if unbacked_symbols_in_sizes and (not is_contiguous_storage_and_layout(x)):
- # realize x; otherwise, the dynamic_reshape_indexer below will fail
- # due to the size_hint's inability to process unbacked SymInts
- # TODO: unbacked should not diverge from backed in determining striding
- # Need to require contiguous here instead of realize, see:
- # https://github.com/pytorch/pytorch/issues/145561
- x = ExternKernel.require_contiguous(x)
- storage, old_layout = as_storage_and_layout(x, want_contiguous=True)
- new_layout = FixedLayout(
- old_layout.device,
- old_layout.dtype,
- new_size,
- FlexibleLayout.contiguous_strides(new_size),
- old_layout.offset,
- old_layout.is_pinned,
- )
- return ReinterpretView(data=storage, layout=new_layout)
- reindex = cls.dynamic_reshape_indexer(old_size, new_size)
- return cls(data=x, size=list(new_size), reindex=reindex)
- @staticmethod
- def resolve_negative_size(
- old_size: Sequence[Expr], new_size: Sequence[Expr]
- ) -> tuple[list[Expr], list[Expr]]:
- new_size = [V.graph.sizevars.simplify(x) for x in new_size]
- old_size = [V.graph.sizevars.simplify(x) for x in old_size]
- new_size = list(new_size)
- for i in range(len(new_size)):
- if new_size[i] == -1:
- new_size[i] = sympy.S.One
- new_size[i] = CleanDiv(sympy_product(old_size), sympy_product(new_size))
- break
- V.graph.sizevars.check_equals(sympy_product(old_size), sympy_product(new_size))
- return old_size, new_size
- @classmethod
- def dynamic_reshape_indexer(
- cls,
- old_size: Sequence[_IntLike],
- new_size: Sequence[_IntLike],
- dense_dim: Optional[int] = None,
- ) -> Callable[[Sequence[_T]], Sequence[_V]]:
- try:
- reindex = cls._dynamic_reshape_indexer(old_size, new_size, dense_dim)
- except (AssertionError, IndexError):
- # optimistic algorithm failed, lets do a fallback
- flat = [sympy_product(old_size)]
- reindex1 = cls._dynamic_reshape_indexer(old_size, flat)
- reindex2 = cls._dynamic_reshape_indexer(flat, new_size)
- reindex = fuse_reindexing(reindex1, reindex2)
- return reindex
- @staticmethod
- def _dynamic_reshape_indexer(
- old_size: Sequence[Expr],
- new_size: Sequence[Expr],
- dense_dim: Optional[int] = None,
- ) -> Callable[[Sequence[Expr]], Sequence[Expr]]:
- """
- Perform a reshape entirely by modifying indexing math
- """
- size_hint = V.graph.sizevars.size_hint
- # TODO: These symbols may not escape, if they don't assert so and
- # treat them as temporary
- vars = [
- sympy_index_symbol_with_prefix(SymT.VIEW, i) for i in range(len(new_size))
- ]
- stack_new = list(zip(vars, new_size))
- stack_old = list(old_size)
- # process the dense dim first
- reordering_dense_dim = (
- dense_dim is not None
- and dense_dim != len(stack_old) - 1
- and len(new_size) == 1
- )
- if reordering_dense_dim:
- assert dense_dim is not None # mypy
- old_dim = stack_old.pop(dense_dim)
- stack_old.append(old_dim)
- view_expr = []
- while stack_new and stack_old:
- size_old = stack_old.pop()
- var, size_new = stack_new.pop()
- if size_old == 1:
- view_expr.append(sympy.S.Zero)
- stack_new.append((var, size_new)) # re-add
- elif size_new == 1:
- stack_old.append(size_old) # re-add
- elif size_hint(size_new) == size_hint(size_old):
- view_expr.append(var)
- V.graph.sizevars.check_equals(size_new, size_old)
- elif size_hint(size_new) < size_hint(size_old):
- while size_hint(size_new) < size_hint(size_old):
- var2, size_new2 = stack_new.pop()
- var = var2 * size_new + var
- size_new = size_new * size_new2
- view_expr.append(var)
- V.graph.sizevars.check_equals(size_new, size_old)
- elif size_hint(size_new) > size_hint(size_old):
- divisor = sympy.S.One
- modulus = size_old
- view_expr.append(ModularIndexing(var, divisor, modulus))
- divisor = divisor * modulus
- while size_hint(size_new) > size_hint(size_old):
- modulus = stack_old.pop()
- view_expr.append(ModularIndexing(var, divisor, modulus))
- divisor = divisor * modulus
- size_old = size_old * modulus
- V.graph.sizevars.check_equals(size_new, size_old)
- else:
- raise AssertionError
- while stack_old:
- size_old = stack_old.pop()
- V.graph.sizevars.check_equals(size_old, 1)
- view_expr.append(sympy.S.Zero)
- while stack_new:
- var, size_new = stack_new.pop()
- V.graph.sizevars.check_equals(size_new, 1)
- if dense_dim is not None and len(new_size) == 1:
- view_expr.reverse()
- # Move the last expression (dense dim) to its original position
- dense_expr = view_expr.pop()
- view_expr.insert(dense_dim, dense_expr)
- else:
- view_expr.reverse()
- assert len(view_expr) == len(old_size)
- def reindex(
- index: Sequence[Expr],
- ) -> Sequence[Expr]:
- assert len(index) == len(vars), (len(index), len(vars))
- replacements = dict(zip(vars, index))
- return tuple(sympy_subs(x, replacements) for x in view_expr)
- return reindex
- @ir_dataclass
- class ReinterpretView(BaseView):
- """Pretend our storage has a different layout"""
- layout: Layout
- def __post_init__(self) -> None:
- super().__post_init__()
- if isinstance(self.data, BaseView):
- object.__setattr__(self, "data", self.data.unwrap_view())
- def __str__(self) -> str:
- return self.str_helper(
- [
- self.data,
- self.layout,
- ]
- )
- __repr__ = __str__
- def get_name(self) -> str:
- return self.data.get_name()
- def get_device(self) -> Optional[torch.device]:
- return self.layout.device
- def get_origin_node(self) -> Optional[torch.fx.Node]:
- return None
- @property
- def dtype(self) -> torch.dtype:
- return self.layout.dtype
- def get_size(self) -> Sequence[Expr]:
- return list(self.layout.size)
- def get_stride(self) -> Sequence[Expr]:
- return list(self.layout.stride)
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- def loader(index: Sequence[Expr]) -> OpsValue:
- indexer = self.layout.make_indexer()
- tmp_loader = ops.load(self.get_name(), indexer(index))
- if self.layout.dtype != self.data.dtype:
- return ops.to_dtype_bitcast(tmp_loader, self.dtype, self.data.dtype)
- else:
- return tmp_loader
- return loader
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- return self.layout.make_indexer()
- def get_layout(self) -> Layout:
- return self.layout
- def freeze_layout(self) -> None:
- pass
- @cache_on_self_and_args("ReinterpretView")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return (
- get_free_symbols(self.layout.size, unbacked_only)
- | get_free_symbols(self.layout.stride, unbacked_only)
- | get_free_symbols(self.layout.offset, unbacked_only)
- )
- def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
- # reinterpret_tensor is similar to as_strided except:
- # - offset is added to the existing offset (rather than replacing it)
- # - view tracking is disabled similar to unsafe_view
- return V.graph.wrapper_code.codegen_reinterpret_view(
- self.data,
- self.layout.size,
- self.layout.stride,
- self.layout.offset,
- writer.writeline if writer is not None else V.graph.wrapper_code.writeline,
- dtype=self.layout.dtype,
- )
- def num_reads(self) -> int:
- return 1
- @ir_dataclass
- class DtypeView(BaseView):
- """Pretend our storage has a different type"""
- target_dtype: torch.dtype
- @classmethod
- def create(cls, x: IRNode, new_dtype: torch.dtype) -> BaseView:
- if is_storage_and_layout(x):
- storage, old_layout = as_storage_and_layout(x)
- new_layout = FixedLayout(
- old_layout.device,
- new_dtype,
- old_layout.size,
- old_layout.stride,
- old_layout.offset,
- old_layout.is_pinned,
- )
- return ReinterpretView(data=storage, layout=new_layout)
- return DtypeView(data=x, target_dtype=new_dtype)
- def __str__(self) -> str:
- return self.str_helper([self.data, self.target_dtype])
- __repr__ = __str__
- @property
- def dtype(self) -> torch.dtype:
- return self.target_dtype
- def get_size(self) -> Sequence[Expr]:
- return self.data.get_size()
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- inner = self.data.make_loader()
- def loader(idx: Sequence[Expr]) -> OpsValue:
- return ops.to_dtype_bitcast(inner(idx), self.target_dtype, self.data.dtype)
- return loader
- class SliceView(View):
- @classmethod
- def normalize_start_end(
- cls, x: IRNode, dim: int, start: int, end: int
- ) -> tuple[int, int]:
- """
- Normalize start and end such that both are in the range
- [0, x.get_size()[dim]] and start <= end.
- """
- sizevars = V.graph.sizevars
- dim_size = x.get_size()[dim]
- if any(free_unbacked_symbols(x) for x in (start, end, dim_size)):
- min_func = sympy.Min
- max_func = sympy.Max
- else:
- min_func = sizevars.evaluate_min
- max_func = sizevars.evaluate_max
- def clamp(x: Expr, lower: int, upper: int) -> Expr:
- clamped_lower = (
- x if sizevars.statically_known_geq(x, lower) else max_func(x, lower)
- )
- clamped_full = (
- clamped_lower
- if sizevars.statically_known_leq(clamped_lower, upper)
- else min_func(clamped_lower, upper)
- )
- return clamped_full
- def clamp_wrap(
- val: Union[int, None], lower: int, upper: int, default: Union[Expr, int]
- ) -> Union[Expr, int]:
- if val is None:
- # TODO(rec): can this really happen?
- return default
- val = cls.handle_negative_index(val, dim_size)
- return clamp(val, lower, upper)
- start = clamp_wrap(start, 0, dim_size, 0)
- end = clamp_wrap(end, start, dim_size, dim_size)
- return start, end
- @classmethod
- def create( # type: ignore[override]
- cls,
- x: IRNode,
- dim: int,
- start: int,
- end: int,
- step: int = 1,
- clamp: bool = True,
- ) -> IRNode:
- step = sympy.expand(step)
- assert isinstance(step, Expr) or step > 0, step
- try:
- if start == 0 and end >= 2**63 - 1 and step == 1:
- return x
- except TypeError:
- pass
- new_size = list(x.get_size())
- # NB: Ordinarily we default to clamping.
- # We only don't clamp for split_with_sizes. For split_with_sizes, sizes should be already valid
- # failing in this situation is ok, since invalid sizes could trigger silent errors.
- if clamp:
- start, end = cls.normalize_start_end(x, dim, start, end)
- new_size[dim] = FloorDiv(end - start + (step - 1), step)
- if is_storage_and_layout(x):
- # Fast path
- storage, old_layout = as_storage_and_layout(x)
- new_stride = list(old_layout.stride)
- new_stride[dim] = new_stride[dim] * step
- new_layout = FixedLayout(
- old_layout.device,
- old_layout.dtype,
- new_size,
- new_stride,
- old_layout.offset + old_layout.stride[dim] * start,
- old_layout.is_pinned,
- )
- return ReinterpretView(data=storage, layout=new_layout)
- def reindex(
- index: Sequence[Expr],
- ) -> Sequence[Expr]:
- assert len(index) == len(new_size), f"wrong ndim {index} {new_size}"
- index = list(index)
- index[dim] = index[dim] * step + start
- return index
- # redirect to a generic view
- return SliceView(data=x, size=new_size, reindex=reindex)
- @ir_dataclass
- class BaseConstant(IRNode):
- dtype: torch.dtype
- device: torch.device
- def get_size(self) -> Sequence[Expr]:
- return ()
- def get_device(self) -> Optional[torch.device]:
- return self.device
- def get_origin_node(self) -> Optional[torch.fx.Node]:
- return None
- def get_reads(self) -> OrderedSet[Dep]:
- return OrderedSet()
- @ir_dataclass
- class Constant(BaseConstant):
- value: Any
- dtype: torch.dtype
- device: torch.device
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- def loader(index: Sequence[Expr]) -> OpsValue:
- return ops.constant(self.value, self.dtype)
- return loader
- def realize(self) -> Optional[str]:
- pass
- def constant_to_device(self, device: torch.device) -> IRNode:
- return Constant(value=self.value, dtype=self.dtype, device=device)
- @ir_dataclass
- class IndexingConstant(BaseConstant):
- index: Any
- dtype: torch.dtype
- device: torch.device
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- def loader(index: Sequence[Expr]) -> OpsValue:
- return ops.index_expr(self.index, self.dtype)
- return loader
- def constant_to_device(self, device: torch.device) -> IRNode:
- return IndexingConstant(index=self.index, dtype=self.dtype, device=device)
- def is_contiguous_strides_for_shape(
- stride: Sequence[_IntLike], shape: Sequence[_IntLike]
- ) -> bool:
- expected_stride = 1
- expected_stride_max = 1
- for x, y in reversed(tuple(zip(shape, stride))):
- if x == 1:
- continue
- if not V.graph.sizevars.statically_known_equals(
- y, expected_stride
- ) and not V.graph.sizevars.statically_known_equals(y, expected_stride_max):
- return False
- expected_stride_max *= sympy.Max(1, x)
- expected_stride *= x
- return True
- def get_align_for_dtype(dtype: torch.dtype) -> int:
- return config.padding_alignment_bytes // dtype.itemsize
- class OutputSpec:
- """Abstract base for Layout, MultiOutputLayout, NoneLayout.
- Represents the memory layout of the output of an Operation."""
- def get_device(self) -> Optional[torch.device]:
- raise NotImplementedError(type(self).__name__)
- def storage_size(self) -> int:
- raise NotImplementedError(type(self).__name__)
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- raise NotImplementedError(type(self).__name__)
- @ir_dataclass
- class Layout(OutputSpec):
- """
- Layout base class
- Carries tensor meta-information including offset and
- whether it is pinned.
- """
- def __init__(
- self,
- device: torch.device,
- dtype: torch.dtype,
- size: Sequence[Expr],
- stride: Optional[Sequence[Expr]] = None,
- offset: Expr = Integer(0),
- is_pinned: bool = False,
- ) -> None:
- if stride is None:
- stride = FlexibleLayout.contiguous_strides(size)
- self.device = device
- self.dtype = dtype
- assert len(size) == len(stride), f"size={size}, stride={stride}"
- assert all(isinstance(s, (Expr, int)) for s in size)
- self._size = size
- self._stride = stride
- self._offset = offset
- self.is_pinned = is_pinned
- # is_pinned implies cpu
- assert (not self.is_pinned) or (self.device.type == "cpu")
- @property
- def size(self) -> Sequence[Expr]:
- return self._size
- @size.setter
- def size(self, value: Sequence[Expr]) -> None:
- self._size = value
- @property
- def stride(self) -> Sequence[Expr]:
- return self._stride
- @stride.setter
- def stride(self, value: Sequence[Expr]) -> None:
- self._stride = value
- @property
- def offset(self) -> Expr:
- return self._offset
- @offset.setter
- def offset(self, value: Expr) -> None:
- self._offset = value
- def __str__(self) -> str:
- offset = ""
- if self.offset != 0:
- offset = f", offset={self.offset}"
- device_index_str = "" if self.device.index is None else f":{self.device.index}"
- is_pinned_str = ""
- if self.is_pinned:
- is_pinned_str = f", is_pinned={self.is_pinned}"
- return (
- f"{type(self).__name__}('{self.device.type}{device_index_str}', {self.dtype}, "
- f"size={self.size}, stride={self.stride}{offset}{is_pinned_str})"
- )
- __repr__ = __str__
- def get_device(self) -> torch.device:
- return self.device
- def get_example(self) -> torch.Tensor:
- with V.fake_mode:
- return torch.empty_strided(
- convert_shape_to_symint(self.size),
- convert_shape_to_symint(self.stride),
- dtype=self.dtype,
- device=self.device,
- pin_memory=self.is_pinned,
- )
- def is_contiguous(self) -> bool:
- return is_contiguous_strides_for_shape(self.stride, self.size)
- @staticmethod
- def is_channels_last_contiguous(
- shape: Sequence[_IntLike], strides: Sequence[_IntLike]
- ) -> bool:
- ndim = len(shape)
- if ndim not in [4, 5] or shape[1] == 1:
- return False
- for left, right, size in zip(
- strides, make_channels_last_strides_for(shape), shape
- ):
- if size != 1 and left != right:
- return False
- return True
- def is_transposed(self) -> bool:
- for left, right, size in zip(
- self.stride,
- reversed(FlexibleLayout.contiguous_strides(list(reversed(self.size)))),
- self.size,
- ):
- if size != 1 and left != right:
- return False
- return True
- def is_stride_ordered(self, order: Sequence[int]) -> bool:
- assert len(self.stride) == len(order)
- # ignore dimensions of size 1, they dont affect layout
- non_1_indices = [
- i
- for i, dim in enumerate(self.size)
- if V.graph.sizevars.size_hint(dim, fallback=2) != 1
- ]
- stride = [self.stride[i] for i in non_1_indices]
- order: Sequence[int] = [order[i] for i in non_1_indices]
- def sorted_indices(arr: Sequence[int]) -> Sequence[int]:
- sorted_arr = sorted(arr)
- return [sorted_arr.index(element) for element in arr]
- # since we may have removed dimensions, need to re-sort & re-index order
- order = sorted_indices(order)
- # reorder the stride given order
- stride_ordered = [-1] * len(order)
- for i in range(len(order)):
- stride_ordered[order[i]] = stride[i]
- # check if it is in ascending order
- for i in range(len(order) - 1):
- expr = stride_ordered[i] > stride_ordered[i + 1]
- if not isinstance(expr, bool):
- expr = V.graph._shape_env.evaluate_expr(
- stride_ordered[i] > stride_ordered[i + 1], size_oblivious=True
- )
- if expr:
- return False
- return True
- def is_channels_last_stride_ordered(self) -> bool:
- # create channels_last order(NCHW, NCDHW, the C is the first order).
- order = [0] + list(reversed(range(1, len(self.stride) - 1)))
- order = [len(order)] + order
- return self.is_stride_ordered(order)
- @staticmethod
- def _pad_strides(
- in_strides: Sequence[int], size: Sequence[Expr], dtype: torch.dtype
- ) -> Sequence[int]:
- """
- The padding does not change stride order but makes sure all strides larger
- than the threshold are multiple of align.
- """
- align = get_align_for_dtype(dtype)
- if len(in_strides) == 0:
- return in_strides
- if not config.pad_channels_last and Layout.is_channels_last_contiguous(
- size, in_strides
- ):
- return in_strides
- current_fx_node = V.get_current_node()
- if hasattr(current_fx_node, "meta") and current_fx_node.meta.get(
- "dislike_padding", False
- ):
- return in_strides
- shape_env = V.graph._shape_env if hasattr(V.graph, "_shape_env") else None
- def contains_unbacked_symints(expr: sympy.Expr | int) -> bool:
- if shape_env is None:
- return False
- if not isinstance(expr, sympy.Expr):
- return False
- return any(shape_env.is_unbacked_symint(s) for s in expr.free_symbols)
- # Skip padding the strides when it contains unbacked symints for now.
- if shape_env and any(contains_unbacked_symints(s) for s in in_strides):
- return in_strides
- stride_order = get_stride_order(in_strides, shape_env)
- fill_order = stride_order2fill_order(stride_order)
- new_strides = [0 for _ in range(len(in_strides))]
- # since we pad when the layout is flexible, we can decide the
- # smallest stride to be 1.
- new_strides[fill_order[0]] = 1
- padded = False
- for rank, idx in enumerate(fill_order[1:], start=1):
- prev_idx = fill_order[rank - 1]
- stride = new_strides[prev_idx] * size[prev_idx]
- # Static stride and meets padding conditions OR
- # Dynamic stride and config.pad_dynamic_shape=True
- require_padding = (
- isinstance(stride, (int, sympy.Integer))
- and stride > config.padding_stride_threshold
- and stride % align != 0
- ) or (isinstance(stride, sympy.Expr) and config.pad_dynamic_shapes)
- new_strides[idx] = stride
- if require_padding:
- new_strides[idx] = ceildiv(stride, align) * align
- padded = True
- if not padded:
- # Consider a tensor with shape [256, 1, 5, 5]
- # Avoid strides like [25, 5, 5, 1] being padded to equivalent strides
- # [25, 25, 5, 1].
- return in_strides
- metrics.num_comprehensive_padding += 1
- return new_strides
- def pad_strides(self) -> None:
- assert isinstance(self, FlexibleLayout), type(self)
- assert self.stride is not None
- self.stride = self._pad_strides(self.stride, self.size, self.dtype)
- def should_pad_strides(self) -> bool:
- return config.comprehensive_padding and isinstance(self, FlexibleLayout)
- def as_fixed(self) -> FixedLayout:
- if isinstance(self, FixedLayout):
- return self
- if self.should_pad_strides():
- self.pad_strides()
- return FixedLayout(
- self.device,
- self.dtype,
- self.size,
- self.stride,
- self.offset,
- self.is_pinned,
- )
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- assert FlexibleLayout.allow_indexing, (
- f"convert {type(self).__name__} to FixedLayout first"
- )
- return self.as_fixed().make_indexer()
- def __eq__(self, other: object) -> bool:
- return (
- isinstance(other, Layout)
- and self.device == other.device
- and self.dtype == other.dtype
- and self.size == other.size
- and self.stride == other.stride
- and self.offset == other.offset
- and self.is_pinned == other.is_pinned
- )
- def storage_size(self) -> Expr:
- return compute_required_storage_length(self.size, self.stride, self.offset) # type: ignore[arg-type]
- @cache_on_self_and_args("Layout")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return (
- get_free_symbols(self.size, unbacked_only)
- | get_free_symbols(self.stride, unbacked_only)
- | get_free_symbols(self.offset, unbacked_only)
- )
- class FixedLayout(Layout):
- """A Tensor layout we cannot change"""
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- """A closure containing math to read a given element"""
- return _fixed_indexer(self.size, self.stride, self.offset)
- class FlexibleLayout(Layout):
- """
- A Tensor layout that we are allowed to change
- Assumption: layout change should NOT add or remove free symbols
- """
- allow_indexing = False
- # WARNING! This doesn't handle zero size tensors correctly
- @staticmethod
- def contiguous_strides(sizes: Sequence[int]) -> list[Expr]:
- if len(sizes) == 0:
- return []
- reversed_strides = [sympy.S.One]
- for size in reversed(sizes[1:]):
- reversed_strides.append(size * reversed_strides[-1])
- return list(reversed(reversed_strides))
- @staticmethod
- def fill_ordered(sizes: Sequence[int], order: Sequence[int]) -> list[Expr]:
- """
- Create a stride based on the order the dimensions should be filled in.
- In this format, channels last would be:
- [1, 3, 2, 0]
- """
- assert OrderedSet(range(len(sizes))) == OrderedSet(order), (sizes, order)
- next_stride = sympy.S.One
- strides = [None] * len(order)
- for i in order:
- strides[i] = next_stride
- next_stride = next_stride * sizes[i]
- return strides
- @staticmethod
- def stride_ordered(sizes: Sequence[int], order: Sequence[int]) -> Sequence[Expr]:
- """
- Create a stride based on the sorted order of a permuted range.
- In this format, channels last would be:
- [3, 0, 2, 1]
- """
- assert OrderedSet(range(len(sizes))) == OrderedSet(order)
- fill_order = stride_order2fill_order(order)
- return FlexibleLayout.fill_ordered(sizes, fill_order)
- @staticmethod
- def stride_ordered_for_memory_format(
- sizes: Sequence[int], memory_format: torch.memory_format
- ) -> Sequence[Expr]:
- """
- Create a stride based on a memory format.
- Memory format is translasted into a stride order,
- so channels_last is the same as:
- FlexibleLayout.stride_ordered(sizes, [3, 0, 2, 1])
- This interface does not support memory_format `torch.preserve_format`
- which should be used to deduce a format from another source
- """
- if memory_format == torch.channels_last:
- return FlexibleLayout.stride_ordered(sizes, NHWC_STRIDE_ORDER)
- elif memory_format == torch.channels_last_3d:
- return FlexibleLayout.stride_ordered(sizes, NHWDC_STRIDE_ORDER)
- elif memory_format == torch.contiguous_format:
- return FlexibleLayout.contiguous_strides(sizes)
- else:
- log.debug(
- "stride_ordered_for_memory_format, unsuppored memory_format: %s",
- memory_format,
- )
- raise NotImplementedError
- @staticmethod
- def same_ordered(
- sizes: Sequence[int], stride: Sequence[_IntLike]
- ) -> Sequence[Expr]:
- """
- Create a stride that has the same stride order as given stride
- For example, if given stride is [1000, 1, 100, 10],
- the fill order should be [1, 3, 2, 0]
- """
- assert len(sizes) == len(stride)
- stride = [V.graph.sizevars.size_hint_or_throw(x) for x in stride]
- fill_order = sorted(range(len(stride)), key=stride.__getitem__)
- return FlexibleLayout.fill_ordered(sizes, fill_order)
- @property
- def size(self) -> Sequence[Expr]:
- return self._size
- @size.setter
- def size(self, value: Sequence[Expr]) -> None:
- self.assert_free_symbol_uses_unchanged("size", value)
- self._size = value
- @property
- def stride(self) -> Sequence[Expr]:
- return self._stride
- @stride.setter
- def stride(self, value: Sequence[Expr]) -> None:
- self.assert_free_symbol_uses_unchanged("stride", value)
- self._stride = value
- @property
- def offset(self) -> Expr:
- return self._offset
- @offset.setter
- def offset(self, value: Expr) -> None:
- self.assert_free_symbol_uses_unchanged("offset", value)
- self._offset = value
- def as_stride_order(
- self, order: Sequence[int], allow_padding: bool = False
- ) -> FixedLayout:
- new_stride = self.stride_ordered(self.size, order)
- if self.should_pad_strides() and allow_padding:
- new_stride = self._pad_strides(new_stride, self.size, self.dtype)
- return FixedLayout(
- self.device,
- self.dtype,
- self.size,
- new_stride,
- self.offset,
- self.is_pinned,
- )
- def as_exact_strides(
- self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
- ) -> FixedLayout:
- new_stride = exact_strides
- if self.should_pad_strides() and allow_padding:
- new_stride = self._pad_strides(new_stride, self.size, self.dtype)
- return FixedLayout(
- self.device,
- self.dtype,
- self.size,
- new_stride,
- self.offset,
- self.is_pinned,
- )
- def as_fill_order(self, order: Sequence[int]) -> FixedLayout:
- new_stride: Sequence[int] = self.fill_ordered(self.size, order)
- if self.should_pad_strides():
- new_stride = self._pad_strides(new_stride, self.size, self.dtype)
- return FixedLayout(
- self.device,
- self.dtype,
- self.size,
- new_stride,
- self.offset,
- self.is_pinned,
- )
- def as_same_order(self, stride: Sequence[_IntLike]) -> FixedLayout:
- new_stride = self.same_ordered(self.size, stride)
- if self.should_pad_strides():
- new_stride = self._pad_strides(new_stride, self.size, self.dtype)
- return FixedLayout(
- self.device,
- self.dtype,
- self.size,
- new_stride,
- self.offset,
- self.is_pinned,
- )
- def get_initial_free_symbol_uses(self) -> dict[tuple[str, bool], sympy.Symbol]:
- initial_free_symbols = {}
- for name in ["size", "stride", "offset"]:
- for unbacked_only in [True, False]:
- key = (name, unbacked_only)
- initial_free_symbols[key] = OrderedSet(
- get_free_symbols(getattr(self, name), unbacked_only)
- )
- return initial_free_symbols
- def assert_free_symbol_uses_unchanged(self, name: str, value: IterateExprs) -> None:
- for unbacked_only in [True, False]:
- old_free_symbols = self.initial_free_symbols[(name, unbacked_only)]
- new_free_symbols = OrderedSet(get_free_symbols(value, unbacked_only))
- assert new_free_symbols == old_free_symbols, (
- f"Expected free symbols unchanged, but got {new_free_symbols} vs {old_free_symbols}"
- )
- def __init__(
- self,
- device: torch.device,
- dtype: torch.dtype,
- size: Sequence[Expr],
- stride_order: Optional[Sequence[Union[int, Integer]]] = None,
- is_pinned: bool = False,
- ) -> None:
- if stride_order:
- strides = FlexibleLayout.fill_ordered(size, stride_order)
- else:
- strides = FlexibleLayout.contiguous_strides(size)
- super().__init__(device, dtype, size, strides, is_pinned=is_pinned)
- # record the initial free symbols to check that we do not add new free symbols
- # later when modifying sizes, strides, and offsets.
- self.initial_free_symbols = self.get_initial_free_symbol_uses()
- class NonOwningLayout(Layout):
- """Is a view into the storage of another tensor"""
- def __init__(self, view: Union[BaseView, TensorBox]) -> None:
- layout = view.get_layout()
- super().__init__(
- layout.device,
- layout.dtype,
- layout.size,
- layout.stride,
- )
- self.view = view
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- return self.as_fixed().make_indexer()
- def maybe_guard_aligned(self) -> bool:
- offset = self.view.get_layout().offset
- if offset == 0:
- return True
- from .utils import ALIGNMENT
- return V.graph.sizevars.statically_known_multiple_of(offset, ALIGNMENT)
- @cache_on_self_and_args("NonOwningLayout")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- assert isinstance(self.view, ReinterpretView)
- box = self.view.data
- assert isinstance(box, StorageBox), type(box)
- input_buffer = box.data
- assert isinstance(input_buffer, Buffer), type(box)
- return input_buffer.layout.get_free_symbol_uses(unbacked_only)
- class CommBufferType(Enum):
- SYMM_MEM = "symm_mem"
- class CommBufferLayout(FixedLayout):
- """
- A layout that signifies the buffer is a comm buffer.
- In terms of striding, the layout is identical to `FixedLayout`.
- Buffers with this layout do not participate in in-place reuse - it can be
- neither the source nor the target for in-place reuse.
- For detailed motivation and usage of this layout, see
- NOTE [lowering-time collective optimization].
- """
- comm_buffer_type: CommBufferType
- group_name: str
- def __init__(
- self,
- layout: FlexibleLayout,
- comm_buffer_type: CommBufferType,
- group_name: str,
- ):
- if not isinstance(layout, FlexibleLayout):
- raise AssertionError(
- "A `CommBufferLayout` can only be initialized with "
- f"a `FlexibleLayout` (got {layout})."
- )
- fixed = layout.as_fixed()
- super().__init__(
- device=fixed.device,
- dtype=fixed.dtype,
- size=fixed.size,
- stride=fixed.stride,
- offset=fixed.offset,
- is_pinned=fixed.is_pinned,
- )
- self.comm_buffer_type = comm_buffer_type
- self.group_name = group_name
- @ir_dataclass
- class NoneLayout(OutputSpec):
- # This is janky, I figured out what fields to populate by just running
- # the model I was interested in and adding properties/methods as needed.
- # This doesn't inherit from Layout because Layout assumes you have stuff
- # like sizes, but I don't really have anything here.
- #
- # If you have an ir.Node with NoneLayout, you probably need to setup
- # dependencies manually in scheduler
- device: Optional[torch.device]
- size: list[int] = dataclasses.field(default_factory=lambda: [0])
- stride: list[int] = dataclasses.field(default_factory=lambda: [0])
- def storage_size(self) -> int:
- return 0
- def as_fixed(self) -> OutputSpec:
- return self
- def get_device(self) -> Optional[torch.device]:
- return self.device
- class MutationLayoutSHOULDREMOVE(Layout):
- def __init__(self, target: IRNode) -> None:
- super().__init__(
- target.get_device_or_error(),
- target.get_dtype(),
- target.get_size(),
- None,
- )
- self.target = target
- name = self.get_buffer().get_name()
- V.graph.mark_buffer_mutated(name)
- @property
- def stride(self) -> Sequence[Expr]: # type: ignore[override]
- return self.real_layout().stride
- @stride.setter # type: ignore[override]
- def stride(self, value: Never) -> None:
- pass # ignore setting of stride
- def storage_size(self) -> Expr:
- return self.real_layout().storage_size()
- def get_buffer(self) -> Buffer:
- def unwrap_views(target: Any) -> Any:
- if isinstance(target, MutationLayoutSHOULDREMOVE):
- return unwrap_views(target.target)
- if isinstance(target, BaseView):
- return unwrap_views(target.unwrap_view())
- if isinstance(target, MutableBox):
- return unwrap_views(target.data)
- return target
- result = unwrap_views(self.target)
- assert isinstance(result, Buffer), type(result)
- return result
- def real_layout(self) -> Layout:
- layout = self.get_buffer().layout
- assert isinstance(layout, Layout)
- return layout
- @classmethod
- def realize_into(
- cls, src: IRNode, dst: IRNode, unsafe_alias: bool = False
- ) -> IRNode:
- dst.realize()
- # NOTE: We must realize users of `dst` before we realize `src`, since
- # realization order determines scheduling order. Otherwise, src's
- # mutation would be scheduled before the existing users of dst!
- V.graph.mark_buffer_mutated(dst.get_name())
- if isinstance(src, TensorBox):
- src = src.data
- # We copy the contents of src into dst. In most cases this should
- # be fused into a single kernel by the scheduler.
- # NOTE: We cannot change src's layout to mutate dst directly as this
- # would alias src to dst, which is not correct as further mutations to
- # dst would effect users of src. However if there are no more users of
- # dst, we can alias src to dst.
- src.realize_hint()
- if not unsafe_alias:
- node = Pointwise.create(
- device=src.get_device(),
- dtype=src.get_dtype(),
- inner_fn=src.make_loader(),
- ranges=[
- V.graph.sizevars.check_equals_and_simplify(a, b)
- for a, b in zip(src.get_size(), dst.get_size())
- ],
- )
- assert isinstance(node, (BaseView, MutableBox))
- src = node.data
- src.realize()
- assert hasattr(src, "data"), src
- assert isinstance(src.data.layout, FlexibleLayout), type(src.data.layout)
- src.data.layout = MutationLayoutSHOULDREMOVE(dst)
- return src.data
- def as_fixed(self) -> Self: # type: ignore[override]
- return self
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- return self.target.make_indexer()
- @ir_dataclass(frozen=False)
- class Buffer(IRNode, CodegenSymbol):
- # Name is sometimes None; e.g., ForceInPlace, where there isn't
- # a meaningful name
- name: Optional[str]
- layout: OutputSpec
- # Multi-output buffers will define 'outputs: List[Buffer]'. Confusingly,
- # MultiOutput does NOT define this!
- def __post_init__(self) -> None:
- super().__post_init__()
- self._post_init_setattr("origin_node", None)
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- return self.get_layout().make_indexer()
- def get_name(self) -> str:
- assert self.name, self
- return self.name
- def get_example(self) -> Union[torch.Tensor, sympy.Symbol]:
- if isinstance(self.layout, Layout):
- return self.layout.get_example()
- raise NotImplementedError(type(self.layout).__name__)
- def get_device(self) -> Optional[torch.device]:
- return self.get_output_spec().get_device()
- def get_defining_op(self) -> Optional[Operation]:
- return None
- @property
- def dtype(self) -> torch.dtype:
- return self.get_layout().dtype
- def get_size(self) -> Sequence[Expr]:
- return [*self.get_layout().size]
- def get_stride(self) -> list[Expr]:
- return [*self.get_layout().stride]
- def get_offset(self) -> Expr:
- return self.get_layout().offset
- def get_layout(self) -> Layout:
- if isinstance(self.layout, Layout):
- return self.layout
- raise NotImplementedError(type(self.layout).__name__)
- def get_output_spec(self) -> OutputSpec:
- return self.layout
- def get_storage_numel(self) -> int:
- return self.get_numel()
- def get_is_pinned(self) -> bool:
- return self.get_layout().is_pinned
- def freeze_layout(self) -> None:
- if isinstance(self.layout, Layout) and not isinstance(
- self.layout, NonOwningLayout
- ):
- self.layout = self.layout.as_fixed()
- def freeze_layout_with_stride_order(
- self, order: Sequence[int], allow_padding: bool = False
- ) -> None:
- assert isinstance(self.layout, FlexibleLayout), type(self.layout)
- self.layout = self.layout.as_stride_order(order, allow_padding=allow_padding)
- def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
- assert isinstance(self.layout, FlexibleLayout), type(self.layout)
- self.layout = self.layout.as_fill_order(order)
- def freeze_layout_with_same_order(self, stride: Sequence[int]) -> None:
- assert isinstance(self.layout, FlexibleLayout), type(self.layout)
- self.layout = self.layout.as_same_order(stride)
- def freeze_layout_with_exact_strides(
- self, exact_strides: Sequence[int], allow_padding: bool = False
- ) -> None:
- assert isinstance(self.layout, FlexibleLayout), type(self.layout)
- self.layout = self.layout.as_exact_strides(
- exact_strides, allow_padding=allow_padding
- )
- def is_zero_elements(self) -> bool:
- return V.graph.sizevars.statically_known_true(sympy.Eq(self.get_numel(), 0))
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- # Loading from a zero-element buffer is a no-op
- if self.is_zero_elements():
- return partial(nop_loader_fn, dtype=self.get_dtype())
- def loader(index: Sequence[Expr]) -> OpsValue:
- indexer = self.make_indexer()
- return ops.load(self.name or "unnamed", indexer(index))
- return loader
- def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
- return self.get_name()
- def decide_layout(self) -> None:
- pass
- def get_inputs_that_alias_output(self) -> Sequence[str]:
- if isinstance(self.layout, NonOwningLayout):
- return [self.layout.view.get_name()]
- return ()
- def get_mutation_names(self) -> Sequence[str]:
- if isinstance(self.layout, MutationLayoutSHOULDREMOVE):
- return [self.layout.target.get_name()]
- return ()
- def get_read_names(self) -> OrderedSet[str]:
- return OrderedSet([self.get_name()])
- @cache_on_self_and_args("Buffer")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def realize(self) -> Optional[str]:
- pass
- def should_allocate(self) -> bool:
- # Returns False by default.
- return False
- @ir_dataclass(frozen=False)
- class OperationBuffer(Buffer, Operation):
- # An operation that produces a single output buffer
- def get_outputs(self) -> list[Buffer]:
- return [self]
- def get_defining_op(self) -> Operation:
- return self
- # Skip implementation in Buffer
- get_operation_name = Operation.get_operation_name
- def __post_init__(self) -> None:
- Buffer.__post_init__(self)
- Operation.__post_init__(self)
- class InputBuffer(Buffer):
- def num_reads(self) -> int:
- return 1
- class DonatedBuffer(InputBuffer):
- """
- Represents a donated buffer which is a saved tensor that is not alias to any
- fwd inputs, fwd user outputs, and bwd outputs. We generally cannot inplace
- reuse the input tensor memory during backward since it might be used in another
- function. However, donated buffer can be inplace reused during backward
- to save memory.
- """
- class ConstantBuffer(InputBuffer):
- override_device: Optional[torch.device] = None
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- def loader(index: Sequence[Expr]) -> OpsValue:
- indexer = self.get_layout().make_indexer()
- return ops.load(
- V.graph.constant_name(self.get_name(), self.override_device),
- indexer(index),
- )
- return loader
- def constant_to_device(self, device: torch.device) -> IRNode:
- return ConstantBuffer(
- name=V.graph.constant_name(self.get_name(), device), layout=self.layout
- )
- @ir_dataclass
- class NoneAsConstantBuffer(IRNode):
- def get_reads(self) -> OrderedSet[Dep]:
- return OrderedSet()
- @cache_on_self_and_args("NoneAsConstantBuffer")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
- return V.graph.wrapper_code.none_str
- def get_output_spec(self) -> OutputSpec:
- return NoneLayout(device=None)
- def has_tensor_output(self) -> bool:
- return False
- @ir_dataclass
- class ShapeAsConstantBuffer(IRNode):
- expr: Expr
- @cache_on_self_and_args("ShapeAsConstantBuffer")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return get_free_symbols(self.expr, unbacked_only)
- def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
- return V.graph.wrapper_code.codegen_sizevar(self.expr)
- def has_tensor_output(self) -> bool:
- return False
- @ir_dataclass(frozen=False)
- class ComputedBuffer(OperationBuffer):
- """
- Represents a buffer that is computed during kernel execution rather than being an input.
- """
- data: Loops
- _force_realize: ClassVar[bool] = False
- @staticmethod
- @contextlib.contextmanager
- def force_realize() -> Iterator[None]:
- old_value = ComputedBuffer._force_realize
- try:
- ComputedBuffer._force_realize = True
- yield
- finally:
- ComputedBuffer._force_realize = old_value
- def get_computed_buffer_name(self) -> Optional[str]:
- """
- Returns self.name if it exists, otherwise returns the name of the data node if that exists.
- If neither exist, returns None.
- """
- if self.name is not None:
- return self.name
- if hasattr(self.data, "name"):
- return self.data.name
- return None
- def num_reads(self) -> int:
- return self.data.num_reads()
- def get_reads(self) -> OrderedSet[Dep]:
- return self.data.get_reads()
- def get_read_names(self) -> OrderedSet[str]:
- return self.data.get_read_names()
- def get_read_writes(self) -> dependencies.ReadWrites:
- if not isinstance(self.data, (Reduction, Scan, Sort, Pointwise)):
- return dependencies.ReadWrites(
- reads=OrderedSet(),
- writes=OrderedSet(),
- index_exprs=OrderedSet(),
- )
- with patch.object(FlexibleLayout, "allow_indexing", True):
- if self.data.get_reduction_type():
- return extract_read_writes(
- self.get_store_function(),
- self.data.get_pointwise_size(),
- self.data.get_reduction_size(),
- )
- else:
- return extract_read_writes(
- self.get_store_function(),
- self.data.get_size(),
- )
- @cache_on_self_and_args("ComputedBuffer")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- # Ordinarily, we'd like to just peek at the arguments list,
- # but ComputedBuffers have no argument list.
- #
- # Morally, this logic needs to be synchronized with the
- # KernelArgs.size calls, which are responsible for making symbols make
- # there way as kernel arguments (and it is precisely passing in one of
- # those symbols that establishes a dependency). However, we haven't
- # started codegen yet so we can't directly reuse that logic.
- #
- # One thing you might wonder is if this is enough for a ComputedBuffer
- # denoting a reduction over i0. Empirically, it is enough, but for an
- # unusual reason: we only need accurate dependencies for item() call,
- # but it's impossible to end up with a reduction over i0 from an
- # item() call without a regular non-reduction buffer first.
- result = self.layout.get_free_symbol_uses(
- unbacked_only
- ) | self.data.get_free_symbol_uses(unbacked_only)
- if self.has_store_function() and isinstance(
- self.get_store_function(), LoopBody
- ):
- result |= self.get_read_writes().get_free_symbol_uses(unbacked_only)
- return result
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- if (
- not self.get_reduction_type()
- and self.name not in V.graph.mutated_buffers
- and self.num_reads() == 0
- and not self._force_realize
- ):
- # inline this op rather than generating ops.load()
- return self.data.make_loader()
- return super().make_loader()
- def has_store_function(self) -> bool:
- return isinstance(self.data, (Reduction, Scan, Sort, Pointwise))
- def get_store_function(self) -> Callable[..., None]:
- indexer = self.get_layout().as_fixed().make_indexer()
- if isinstance(self.data, (Reduction, Scan, Sort)):
- return partial(self.data.store_reduction, self.name, indexer)
- else:
- assert isinstance(self.data, Pointwise), type(self.data)
- return partial(self.data.store_output, self.name, indexer)
- def get_fill_order(self) -> Optional[list[int]]:
- """
- If our layout is still flexible, try to determine the stride order based on stride orders of reads.
- TODO(jansel): A better algorithm here would look at downstream consumers of this
- value and try to do global graph-level layout optimization.
- This is also something just begging to be autotuned.
- """
- if isinstance(self.layout, FlexibleLayout):
- (index_vars, reduction_vars), _ = dependencies.index_vars_squeeze(
- self.data.get_pointwise_size(), self.data.get_reduction_size()
- )
- reads = self.get_read_writes().reads
- # only consider reads to buffer of same size
- # ignore StarDeps because they don't contribute stride information
- assert all(
- isinstance(r, (dependencies.StarDep, dependencies.MemoryDep))
- for r in reads
- )
- reads = [
- sympy_subs(r.index, {v: sympy.S.Zero for v in reduction_vars if v != 0})
- for r in reads
- if isinstance(r, dependencies.MemoryDep)
- ]
- if reads:
- if isinstance(self.data, (Scan, Sort)):
- indices = self.data.reindex(index_vars, reduction_vars)
- else:
- indices = index_vars
- stride_lengths = [
- V.graph.sizevars.stride_hints(expr, indices) for expr in reads
- ]
- from .scheduler import pick_loop_order
- return pick_loop_order(stride_lengths, self.get_size())
- return None
- def decide_layout(self) -> None:
- if isinstance(self.layout, FlexibleLayout):
- order = self.get_fill_order()
- if order:
- self.freeze_layout_with_fill_order(order)
- else:
- self.freeze_layout()
- @cache_on_self
- def get_default_sizes_body(
- self,
- ) -> tuple[
- tuple[list[Expr], list[Expr]],
- LoopBody,
- tuple[list[Expr], list[Expr]],
- ]:
- args, var_ranges = dependencies.index_vars_squeeze(
- self.data.get_pointwise_size(), self.data.get_reduction_size(), prefix="q"
- )
- with patch.object(ConstantBuffer, "override_device", self.get_device()):
- body = LoopBody(
- self.get_store_function(),
- (args if self.get_reduction_type() else args[:1]),
- var_ranges,
- *args,
- )
- index_vars = []
- reduce_vars: list[Any] = []
- index_size = []
- reduce_size = []
- for v, s in var_ranges.items():
- if v in args[0]:
- assert not reduce_vars
- index_vars.append(v)
- index_size.append(s)
- else:
- assert v in args[1]
- reduce_vars.append(v)
- reduce_size.append(s)
- return (index_size, reduce_size), body, (index_vars, reduce_vars)
- def simplify_and_reorder(
- self,
- extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
- recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
- ) -> tuple[tuple[list[Expr], list[Expr]], Optional[LoopBody]]:
- """
- This is a main place where we do loop transformations in a
- backend-agnostic way.
- Here we:
- 1) Remove any 1 dimensions
- 2) Fuse contiguous dimensions together
- 3) Reorder dimensions based on stride orders
- Optional argument extra_indexing_constraints can be used to append additional
- indexing expressions to existing ones derived from buffer's body. This can be useful
- to fuse scheduler nodes with compatible ranges, e.g. (s0*s1*...,) and (s0, s1, s2, ...)
- on CPU by preventing indexing simplifications and obtaining index/reduce ranges for
- the scheduler node compatible with other nodes.
- Optional argument recompute_sizes_body_func can be used to recompute sizes and body
- on the default body. This can be useful to append additional loop transformations.
- """
- (
- (index_size, reduce_size),
- body,
- (index_vars, reduce_vars),
- ) = self.get_default_sizes_body()
- if recompute_sizes_body_func:
- (
- (index_size, reduce_size),
- body,
- (index_vars, reduce_vars),
- ) = recompute_sizes_body_func(
- (index_size, reduce_size), body, (index_vars, reduce_vars)
- )
- index_formulas = [*body.indexing_exprs.values()]
- if extra_indexing_constraints is not None:
- assert (
- isinstance(extra_indexing_constraints, tuple)
- and len(extra_indexing_constraints) == 2
- )
- extra_indexing_ranges, extra_indexing_expr = extra_indexing_constraints
- assert isinstance(extra_indexing_ranges, dict), type(extra_indexing_ranges)
- assert isinstance(extra_indexing_expr, list), type(extra_indexing_expr)
- assert all(isinstance(f, Expr) for f in extra_indexing_expr)
- expected_var_ranges = body.var_ranges
- assert expected_var_ranges == extra_indexing_ranges, (
- expected_var_ranges,
- extra_indexing_ranges,
- )
- # remove already existing expressions
- extra_indexing_expr = [
- e for e in extra_indexing_expr if e not in index_formulas
- ]
- index_formulas += extra_indexing_expr
- memory_addrs = [*body.get_write_exprs()]
- if not V.graph.has_feature(self, BackendFeature.PREFER_STORE_LOOP_ORDER):
- memory_addrs.extend(body.get_read_exprs())
- def simplify_and_reorder(
- x_vars: Sequence[sympy.Symbol],
- support_vars: Sequence[sympy.Symbol],
- sizes: Sequence[int],
- simplify_loops: bool,
- ) -> tuple[
- list[int],
- Callable[[Sequence[int]], Sequence[int]],
- Callable[[Sequence[int]], Sequence[int]],
- ]:
- sizes, reindex0, reindex1 = self._apply_loop_reordering(
- x_vars, support_vars, sizes, memory_addrs
- )
- # for NHWC: reindex0([0,1,2,3]) = [0,2,3,1], reindex1([0,1,2,3]) = [0,3,2,1]
- x_vars = reindex0(x_vars)
- if simplify_loops:
- sizes, reindex2, _prune = V.graph.sizevars._simplify_loops(
- x_vars,
- sizes,
- index_prevent_reordering(index_formulas, x_vars, sizes),
- )
- reindex = fuse_reindexing(reindex1, reindex2)
- else:
- reindex = reindex1
- return sizes, reindex, reindex1
- support_vars = index_vars + reduce_vars
- should_merge_loops = (
- not is_gpu(get_device_type(self)) or not config.loop_ordering_after_fusion
- )
- iter_ranges, iter_reindex, _ = simplify_and_reorder(
- index_vars,
- support_vars,
- index_size,
- should_merge_loops,
- )
- # Like iteration dimensions, we may also want to delay merging reduction dimensions.
- # E.g., if we reduce a tensor [M, N, K] for its M and N dimensions followed by a pointwise
- # kernel, merging M and N dimension too early makes it hard to decide what loop order
- # we should pick for the piontwise kernel so that it is fusible with the reduction.
- reduce_ranges, reduce_reindex, _ = simplify_and_reorder(
- reduce_vars, support_vars, reduce_size, should_merge_loops
- )
- # retrace the loop body with simplification and reordering applied
- (iter_vars, reduce_vars), var_ranges = dependencies.index_vars_no_squeeze(
- iter_ranges,
- reduce_ranges,
- prefix="p",
- )
- body = LoopBody(
- body,
- [iter_reindex(iter_vars), reduce_reindex(reduce_vars)],
- var_ranges,
- iter_vars,
- reduce_vars,
- )
- return (iter_ranges, reduce_ranges), body
- @staticmethod
- def _apply_loop_reordering(
- index_vars: Sequence[sympy.Symbol],
- support_vars: Sequence[sympy.Symbol],
- sizes: Sequence[int],
- memory_addrs: list[sympy.Expr],
- priority_idx: Optional[list[int]] = None,
- ) -> tuple[
- list[int],
- Callable[[Sequence[int]], Sequence[int]],
- Callable[[Sequence[int]], Sequence[int]],
- ]:
- """
- Shuffle the order of loops around to hopefully improve performance.
- """
- from .scheduler import pick_loop_order
- if priority_idx is None:
- priority_idx = []
- try:
- strides = [
- V.graph.sizevars.stride_hints(expr, index_vars, support_vars)
- for expr in memory_addrs
- ]
- assert len(strides) == len(memory_addrs) and len(strides[0]) == len(
- index_vars
- )
- order = list(reversed(pick_loop_order(strides, sizes, priority_idx)))
- except Exception:
- if config.debug:
- log.warning(
- "Did not simplify complex index:\n%s\n%s",
- dict(zip(index_vars, sizes)),
- memory_addrs,
- )
- order = list(range(len(sizes)))
- sizes = [sizes[i] for i in order]
- return sizes, same_reorder(order), inverse_reorder(order)
- def get_reduction_size(self) -> Sequence[Expr]:
- return self.data.get_reduction_size()
- def get_reduction_type(self) -> Optional[str]:
- return self.data.get_reduction_type()
- def is_no_op(self) -> bool:
- return self.data.is_zero_elements()
- def should_allocate(self) -> bool:
- return True
- def constant_to_device(self, device: torch.device) -> IRNode:
- """Move this to a given device. Requires that all reads are to constants."""
- return self.data.constant_to_device(device)
- class TemplateBuffer(OperationBuffer):
- """
- Represents a Triton (in the future other type) of template operator
- that we can fuse an epilogue onto.
- """
- def __init__(
- self,
- layout: OutputSpec,
- inputs: Sequence[IRNode],
- make_kernel_render: Optional[Callable[..., Any]],
- ) -> None:
- super().__init__(name=None, layout=layout)
- self.inputs = InputsKernel.unwrap_storage(inputs)
- self.make_kernel_render = make_kernel_render
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- def get_read_writes(self) -> dependencies.ReadWrites:
- return self.extract_read_writes(normalize=True)
- def extract_read_writes(self, normalize: bool = False) -> dependencies.ReadWrites:
- name = self.get_name()
- indexer = self.get_layout().make_indexer()
- def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
- assert len(rindex) == 0
- return ops.store(name, indexer(index), "fake")
- deps = dependencies.extract_read_writes(
- dummy, self.get_size(), (), normalize=normalize
- )
- for inp in self.inputs:
- assert isinstance(inp, (ReinterpretView, Buffer)), type(inp)
- assert isinstance(inp.layout, Layout), type(inp.layout)
- indexer = inp.layout.make_indexer()
- def dummy(index: Sequence[Any], rindex: Sequence[Any]) -> Any:
- assert len(rindex) == 0
- return ops.load(inp.get_name(), indexer(index))
- deps.reads |= dependencies.extract_read_writes(
- dummy, inp.get_size(), (), normalize=normalize
- ).reads
- return deps
- def get_reduction_size(self) -> Sequence[Expr]:
- return sympy.S.One
- def get_reduction_type(self) -> Optional[str]:
- return None
- def should_allocate(self) -> bool:
- return True
- def simplify_and_reorder(
- self,
- extra_indexing_constraints: Optional[tuple[dict[Any, Any], list[Any]]] = None,
- recompute_sizes_body_func: Optional[Callable[..., Any]] = None,
- ) -> tuple[tuple[Sequence[Expr], list[Expr]], Optional[LoopBody]]:
- return (
- (
- self.get_size(),
- [],
- ),
- None,
- )
- class TritonTemplateBuffer(TemplateBuffer):
- def __init__(
- self,
- layout: Layout,
- inputs: Sequence[IRNode],
- make_kernel_render: Optional[Callable[_P, _T]],
- mutated_inputs: Optional[Iterable[IRNode]] = None,
- allowed_prologue_inps: Optional[OrderedSet[str]] = None,
- ) -> None:
- """
- NOTE:[TritonTemplates with multiple outputs]
- We want the ability for TritonTemplates to output multiple tensors. Triton
- kernels have no notion of outputs and this is done by creating tensors that
- are then mutated by the kernel. Currently our STORE_OUTPUT codegen doesn't
- support creating multinode outputs for triton templates.
- We work around this by creating an extra input buffer during the lowering
- and we mark them as mutated inputs.
- """
- super().__init__(layout, inputs, make_kernel_render)
- self.mutated_inputs = mutated_inputs
- self.outputs: list[Buffer] = [self]
- if mutated_inputs is not None:
- # Ensure that the mutated inputs are only allowed for certain nodes
- allowed_set = (
- torch.ops.higher_order.flex_attention,
- torch.ops.higher_order.flex_attention_backward,
- )
- current_node = V.graph.current_node.target
- assert current_node in allowed_set, (
- f"Mutated inputs are only allowed for {allowed_set} but got {current_node}"
- )
- assert isinstance(self.inputs[0], IRNode), type(self.inputs[0])
- device = self.inputs[0].get_device()
- self.outputs += [
- MutationOutput(NoneLayout(device=device), buf, self)
- for buf in mutated_inputs
- ]
- self.allowed_prologue_inps = (
- allowed_prologue_inps if allowed_prologue_inps else OrderedSet()
- )
- self.subgraph_inps: Optional[list[Optional[Union[IRNode, sympy.Expr]]]] = None
- self.subgraph_outs: Optional[list[Optional[IRNode]]] = None
- @cache_on_self_and_args("TritonTemplateBuffer")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- res = super().get_free_symbol_uses(unbacked_only)
- subgraph_outs = self.subgraph_outs if self.subgraph_outs else []
- subgraph_inps = self.subgraph_inps if self.subgraph_inps else []
- for inp in subgraph_inps:
- if isinstance(inp, sympy.Expr):
- res.update(get_free_symbols(inp, unbacked_only))
- elif isinstance(inp, IRNode):
- res.update(inp.get_free_symbol_uses(unbacked_only))
- else:
- assert inp is None
- for out in subgraph_outs:
- if isinstance(out, IRNode):
- res.update(out.get_free_symbol_uses(unbacked_only))
- else:
- assert out is None
- return res
- def get_outputs(self) -> list[Buffer]:
- return self.outputs
- def get_allowed_prologue_inps(self) -> OrderedSet[str]:
- return self.allowed_prologue_inps
- def __str__(self) -> str:
- out = f"TritonTemplateBuffer(layout={self.layout})"
- return out
- PrimitiveInfoType = Union[int, float, bool, str, list[Union[int, str, float, bool]]]
- class ChoiceCaller:
- """
- Represents a possible choice used in autotune_process.py.
- During autotuning, self.benchmark() is first called to get benchmark result,
- and if this choice is selected, self.output_node() is called to get the output_node.
- Children classes: TritonTemplateCaller, CUDATemplateCaller.
- """
- def __init__(
- self,
- name: str,
- input_nodes: list[Buffer],
- layout: Layout,
- description: str,
- ) -> None:
- super().__init__()
- self.name = name
- self.layout = layout
- self.input_nodes = input_nodes
- # An additional description used to describe the choice (useful for
- # knowing what autotuning is choosing)
- self.description = description
- def benchmark(self, *args: Any, out: torch.Tensor) -> float:
- algo = self.to_callable()
- benchmark_configs = {
- "warmup": autotune_warmup,
- "rep": autotune_rep,
- }
- if config.profile_bandwidth_with_do_bench_using_profiling:
- return do_bench_using_profiling(lambda: algo(*args), **benchmark_configs)
- return benchmarker.benchmark(algo, args, {"out": out}, **benchmark_configs)
- def call_name(self) -> str:
- raise NotImplementedError
- def to_callable(self) -> Callable[..., Any]:
- raise NotImplementedError
- def kernel_hash_key(self) -> str:
- """
- Hash key for the underlying kernel. By default, we assume there are no
- runtime params, so kernel hash key defaults to choice caller's hash key.
- """
- return self.hash_key()
- def hash_key(self) -> str:
- raise NotImplementedError
- def output_node(self) -> Union[TensorBox, ShapeAsConstantBuffer]:
- raise NotImplementedError
- def info_dict(self) -> dict[str, Union[PrimitiveInfoType, list[PrimitiveInfoType]]]:
- """Information returned here is logged to the autotune log file when that is enabled."""
- return {}
- def autoheuristic_id(self) -> str:
- return "unsupported_choice"
- class TritonTemplateCallerBase(ChoiceCaller):
- def get_make_kernel_render(self) -> Any:
- raise NotImplementedError
- class MultiTemplateBuffer(TritonTemplateBuffer):
- """
- Represents a Buffer with multiple backing implementation choices.
- Choices can be TritonTemplates or ExternKernels. During scheduling if there is a potential
- epilogue we will benchmark each of the choices with the epilogue to determine an implementation.
- Otherwise, the fastest base choice will be chosen.
- """
- def __init__(
- self,
- layout: Layout,
- inputs: Sequence[IRNode],
- choice_timings_fn: Callable[[Optional[int]], dict[ChoiceCaller, float]],
- unfiltered_choices: list[ChoiceCaller],
- allowed_prologue_inps: OrderedSet[str],
- ) -> None:
- super().__init__(
- layout=layout,
- inputs=inputs,
- make_kernel_render=None,
- allowed_prologue_inps=allowed_prologue_inps,
- )
- self._choice_timings_fn = choice_timings_fn
- self._choice_timings: dict[Optional[int], dict[ChoiceCaller, float]] = {}
- self.original_inputs = inputs
- self._output_plannable = all(
- isinstance(choice, TritonTemplateCallerBase)
- or (
- isinstance(choice, torch._inductor.select_algorithm.ExternKernelCaller)
- and choice.has_out_variant
- )
- for choice in unfiltered_choices
- )
- self._make_kernel_renders: dict[Optional[int], Any] = {}
- @property
- def output_plannable(self) -> bool:
- """
- Are all possible choices TritonTemplates or Extern Kernels with out variants
- """
- return self._output_plannable
- def choice_timings(
- self, hint_override: Optional[int] = None
- ) -> dict[ChoiceCaller, float]:
- if hint_override not in self._choice_timings:
- self._choice_timings[hint_override] = self._choice_timings_fn(hint_override)
- return self._choice_timings[hint_override]
- @contextlib.contextmanager
- def swap_as_triton_caller(self, caller: TritonTemplateCallerBase) -> Iterator[None]:
- assert isinstance(
- caller, torch._inductor.select_algorithm.TritonTemplateCaller
- ), type(caller)
- assert self.layout == caller.layout
- render = self.make_kernel_render
- self.make_kernel_render = caller.get_make_kernel_render()
- try:
- yield
- finally:
- self.make_kernel_render = render
- def finalize_as_triton_caller(self, caller: TritonTemplateCallerBase) -> None:
- assert isinstance(
- caller, torch._inductor.select_algorithm.TritonTemplateCaller
- ), type(caller)
- assert self.get_size() == caller.layout.size
- assert self.get_stride() == caller.layout.stride
- self.make_kernel_render = caller.get_make_kernel_render()
- def get_min_choice(
- self, hint_override: Optional[int] = None
- ) -> tuple[ChoiceCaller, float]:
- timings = self.choice_timings(hint_override=hint_override)
- min_choice = min(timings, key=timings.get) # type: ignore[arg-type]
- return (min_choice, timings[min_choice])
- def finalize_as_triton_callers(
- self, callers: dict[Optional[int], TritonTemplateCallerBase]
- ) -> None:
- """Finalize with multiple callers for different hint overrides"""
- for hint_override, caller in callers.items():
- self._make_kernel_renders[hint_override] = caller.get_make_kernel_render()
- # Set the default to be the one without hint override
- self.make_kernel_render = self._make_kernel_renders[None]
- class CUDATemplateBuffer(TemplateBuffer):
- def __init__(
- self,
- layout: Layout,
- inputs: Sequence[IRNode],
- make_kernel_render: Callable[_P, _T],
- workspace_size: int,
- template: CUDATemplate,
- supports_epilogue_fusion: bool,
- ) -> None:
- super().__init__(layout, inputs, make_kernel_render)
- # Global memory (in bytes) needed for this template.
- self.workspace_size = workspace_size
- self.template = template
- self.supports_epilogue_fusion = supports_epilogue_fusion
- def get_workspace_size(self) -> int:
- return self.workspace_size if self.workspace_size is not None else 0
- def emulate_store_fn(self) -> None:
- for output in self.get_outputs():
- ops.store(output.get_name(), None, None)
- class CppTemplateBuffer(TemplateBuffer):
- def __init__(
- self,
- layout: Layout,
- inputs: Sequence[IRNode],
- make_kernel_render: Callable[_P, _T],
- template: CUDATemplate,
- choice: Any,
- ) -> None:
- super().__init__(layout, inputs, make_kernel_render)
- self.template = template
- self.choice = choice
- self.outputs: Optional[list[Buffer]] = None
- def get_layout(self) -> Layout:
- if isinstance(self.layout, MultiOutputLayout):
- assert isinstance(self.outputs, Iterable), type(self.outputs)
- first_output = self.outputs[0]
- assert isinstance(first_output, Buffer), type(first_output)
- layout = first_output.layout
- assert isinstance(layout, Layout), type(layout)
- return layout
- else:
- return super().get_layout()
- class CuteDSLTemplateBuffer(TemplateBuffer):
- """
- Buffer for CuteDSL (CUTLASS Python DSL) template kernels.
- Similar to other template buffers but specialized for CuteDSL operations.
- """
- def __init__(
- self,
- layout: Layout,
- inputs: Sequence[IRNode],
- make_kernel_render: Callable[_P, _T],
- template: Any,
- mutated_inputs: Optional[Iterable[IRNode]] = None,
- ) -> None:
- super().__init__(layout, inputs, make_kernel_render)
- self.template = template
- self.mutated_inputs = mutated_inputs
- self.outputs: list[Buffer] = [self]
- if mutated_inputs is not None:
- assert isinstance(self.inputs[0], IRNode), type(self.inputs[0])
- device = self.inputs[0].get_device()
- self.outputs += [
- MutationOutput(NoneLayout(device=device), buf, self)
- for buf in mutated_inputs
- ]
- def get_outputs(self) -> list[Buffer]:
- return self.outputs
- def is_node_sequence(
- nodes: Sequence[Union[IRNode, Sequence[IRNode]]],
- ) -> TypeIs[Sequence[IRNode]]:
- return all(isinstance(n, IRNode) for n in nodes)
- @ir_dataclass(frozen=False)
- class InputsKernel(OperationBuffer):
- inputs: Sequence[Union[IRNode, Sequence[IRNode]]]
- def input_name(self, i: int) -> str:
- input = self.inputs[i]
- assert isinstance(input, IRNode)
- return input.get_name()
- def get_read_writes(self) -> dependencies.ReadWrites:
- reads = OrderedSet[dependencies.Dep]()
- StarDep = dependencies.StarDep
- for input in self.inputs:
- if isinstance(input, Sequence):
- reads.update(StarDep(x.get_name()) for x in input)
- elif isinstance(input, ShapeAsConstantBuffer):
- # Skip creating dependency for symbolics as they're visible globally
- continue
- else:
- reads.add(StarDep(input.get_name()))
- writes = OrderedSet[dependencies.Dep](
- StarDep(buf.get_name()) for buf in self.get_outputs()
- )
- return dependencies.ReadWrites(
- reads=reads,
- writes=writes,
- index_exprs=OrderedSet(),
- )
- def get_reads(self) -> OrderedSet[Dep]:
- return self.get_read_writes().reads
- @classmethod
- def unwrap_storage_for_input(cls, x: IRNode) -> IRNode:
- if isinstance(x, TensorBox):
- x = x.data
- if isinstance(x, StorageBox):
- x = x.data
- if isinstance(x, BaseView) and not isinstance(x, ReinterpretView):
- x = ExternKernel.realize_input(x)
- if isinstance(x, TensorBox):
- # when converting to ReinterpretView fails in the
- # realize_input call above, the result will be wrapped
- # into TensorBox / StorageBox pair as a result of the
- # cls.copy_input call; so we should unwrap recursively
- return cls.unwrap_storage_for_input(x)
- if isinstance(x, TorchBindObject):
- return x
- assert isinstance(x, (Buffer, ReinterpretView)), type(x)
- return x
- @staticmethod
- def unwrap_storage(
- inputs: Sequence[Union[IRNode, Sequence[IRNode]]],
- ) -> list[Union[IRNode, Sequence[IRNode]]]:
- inputs_new: list[Union[IRNode, Sequence[IRNode]]] = []
- for x in inputs:
- if isinstance(x, Sequence):
- x = [InputsKernel.unwrap_storage_for_input(i) for i in x]
- else:
- x = InputsKernel.unwrap_storage_for_input(x)
- inputs_new.append(x)
- return inputs_new
- def is_extern(self) -> bool:
- return True
- def num_reads(self) -> int:
- return 1
- @cache_on_self_and_args("InputsKernel")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- r = OrderedSet[sympy.Symbol]()
- for inp in self.inputs:
- if isinstance(inp, IRNode):
- r |= inp.get_free_symbol_uses(unbacked_only)
- else:
- for inner_inp in inp:
- r |= inner_inp.get_free_symbol_uses(unbacked_only)
- return r
- class NopKernel(InputsKernel):
- def is_no_op(self) -> bool:
- return True
- def get_reads(self) -> OrderedSet[Dep]:
- return OrderedSet()
- class ConcatKernel(NopKernel):
- """
- There isn't actually a real kernel for concat, we just change the
- storage for the upstream data.
- """
- @classmethod
- def create(cls, inputs: Sequence[IRNode], dim: int) -> StorageBox:
- """
- Create the concat kernel from inputs
- """
- device = inputs[0].get_device()
- dtype = inputs[0].get_dtype()
- new_size = list(inputs[0].get_size())
- offsets_start = [0]
- offsets_end = [new_size[dim]]
- assert 0 <= dim < len(new_size)
- for i in range(1, len(inputs)):
- input_size = inputs[i].get_size()
- offsets_start.append(new_size[dim])
- assert len(input_size) == len(new_size)
- assert inputs[i].get_dtype() == dtype
- assert inputs[i].get_device() == device
- for j in range(len(new_size)):
- if j == dim:
- new_size[j] = new_size[j] + input_size[j]
- else:
- new_size[j] = V.graph.sizevars.check_equals_and_simplify(
- new_size[j], input_size[j]
- )
- offsets_end.append(new_size[dim])
- output_stride: Sequence[int] = FlexibleLayout.contiguous_strides(new_size)
- if config.comprehensive_padding:
- # Ensure the output stride matches the alignment requirements
- output_stride = Layout._pad_strides(
- output_stride, new_size, inputs[0].dtype
- )
- # If any of the inputs is in CL format, use CL format for the output
- for i in range(len(inputs)):
- x = inputs[i]
- if is_storage_and_layout(x):
- layout = x.get_layout()
- if isinstance(
- layout, FixedLayout
- ) and Layout.is_channels_last_contiguous(layout.size, layout.stride):
- # use CL stride for the output
- output_stride = make_channels_last_strides_for(new_size)
- break
- any_input_is_storage_and_layout = any(is_storage_and_layout(x) for x in inputs)
- fx_node_args = V.graph.current_node.args[0]
- assert isinstance(fx_node_args, list), type(fx_node_args)
- # If any of the inputs has meta tensor and the meta tensor is in CL format, use CL format for the output
- if any_input_is_storage_and_layout is False and any(
- "val" in arg.meta
- and (
- arg.meta["val"].is_contiguous(memory_format=torch.channels_last)
- or arg.meta["val"].is_contiguous(memory_format=torch.channels_last_3d)
- )
- for arg in fx_node_args
- ):
- output_stride = make_channels_last_strides_for(new_size)
- is_pinned = all(
- is_storage_and_layout(x) and x.get_layout().is_pinned for x in inputs
- )
- assert device is not None
- concat_kernel = ConcatKernel(
- name=None,
- layout=FixedLayout(
- device=device,
- dtype=dtype,
- size=new_size,
- stride=output_stride,
- is_pinned=is_pinned,
- ),
- inputs=[],
- )
- kernel = StorageBox(concat_kernel)
- op_names = []
- for i, inp in enumerate(inputs):
- assert isinstance(inp, (BaseView, MutableBox)), type(inp)
- input_buffer = cls.realize_into(
- inp,
- SliceView.create(
- kernel, dim, offsets_start[i], offsets_end[i], clamp=False
- ),
- )
- assert isinstance(input_buffer, Buffer), type(input_buffer)
- assert isinstance(concat_kernel.inputs, list), type(concat_kernel.inputs)
- concat_kernel.inputs.append(input_buffer)
- if isinstance(inp.data, BaseView):
- input_unwrapped = inp.data.unwrap_view()
- else:
- input_unwrapped = inp.data
- if (
- isinstance(input_unwrapped, StorageBox)
- and input_unwrapped.is_input_buffer()
- and (dev := inp.get_device()) is not None
- and is_gpu(dev.type)
- and not is_dynamic(input_buffer)
- ):
- op_names.append(input_buffer.get_operation_name())
- if len(op_names) > 1 and V.graph.has_feature(device, BackendFeature.FOREACH):
- V.graph.register_operation_list(op_names)
- concat_kernel.name = V.graph.register_buffer(concat_kernel)
- concat_kernel.inputs = cls.unwrap_storage(concat_kernel.inputs)
- V.graph.register_operation(concat_kernel)
- return kernel
- @classmethod
- def can_realize_into_without_copy(
- cls, src: IRNode, dst: Optional[IRNode] = None
- ) -> bool:
- if isinstance(src, TensorBox):
- # unwrap a TensorBox
- return cls.can_realize_into_without_copy(src.data, dst)
- assert isinstance(src, (BaseView, StorageBox)), type(src)
- if isinstance(src.data, MultiTemplateBuffer):
- if (
- not isinstance(src.data.layout, FixedLayout)
- or not src.data.output_plannable
- ):
- return False
- # we call can_realize_into_without_copy in cat lowering before we've decided
- # on output format, optimistically assume layout matches
- if dst is None:
- return True
- # otherwise, check equality of layouts
- if not len(src.get_stride()) == len(dst.get_stride()):
- return False
- return all(
- V.graph.sizevars.statically_known_equals(s1, s2)
- for s1, s2 in zip(src.get_stride(), dst.get_stride())
- )
- return (
- hasattr(src.data, "layout")
- and isinstance(src.data.layout, FlexibleLayout)
- and not isinstance(src.data, ExternKernelAlloc)
- )
- @cache_on_self_and_args("ConcatKernel")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return NopKernel.get_free_symbol_uses(self, unbacked_only)
- @classmethod
- def realize_into(cls, src: IRNode, dst: IRNode) -> IRNode:
- # Attempt to turn this into a ReinterpretView rather than assert.
- # This has concessions around layout, as as_storage_and_layout
- # can cause us to go from flexible to fixed layout.
- if not isinstance(dst, ReinterpretView):
- if is_storage_and_layout(dst):
- storage, layout = as_storage_and_layout(dst)
- dst = ReinterpretView(data=storage, layout=layout)
- assert isinstance(dst, ReinterpretView), type(dst)
- if isinstance(src, TensorBox):
- # unwrap a TensorBox
- return cls.realize_into(src.data, dst)
- if isinstance(src, StorageBox):
- src.realize()
- # ExternKernelAlloc has specific requirements for output layout, should create a copy
- assert hasattr(src.data, "layout")
- if cls.can_realize_into_without_copy(src, dst):
- src.data.layout = NonOwningLayout(dst)
- return src.data
- # introduce a copy
- pw = Pointwise.create(
- device=src.get_device(),
- dtype=src.get_dtype(),
- inner_fn=src.make_loader(),
- ranges=[
- V.graph.sizevars.check_equals_and_simplify(a, b)
- for a, b in zip(src.get_size(), dst.get_size())
- ],
- )
- return cls.realize_into(pw, dst)
- def should_allocate(self) -> bool:
- return True
- @ir_dataclass(frozen=False)
- class ExternKernel(InputsKernel):
- """
- A class that represents Kernels which are not directly lowered to Inductor
- Loop Level IR, such as custom operators, or aten operators which we fallback to.
- """
- constant_args: Sequence[Any] = ()
- kwargs: dict[str, Any] = dataclasses.field(default_factory=dict)
- output_view: Optional[ReinterpretView] = None
- python_kernel_name: Optional[str] = None
- cpp_kernel_name: Optional[str] = None
- # FIXME: in some cases we sill need to explicitly pass in ordered_kwargs_for_cpp_kernel
- # We shouldn't need to do this since the information can be retrieved from op_overload._schema.
- ordered_kwargs_for_cpp_kernel: Iterable[str] = dataclasses.field(
- default_factory=list
- )
- op_overload: Optional[_OpOverloads] = None
- arg_properties: Optional[list[dict[str, Any]]] = None
- allarg_properties: dict[str, dict[str, Any]] = dataclasses.field(
- default_factory=dict
- )
- kwarg_properties: Optional[dict[str, dict[str, Any]]] = None
- unbacked_bindings: dict[sympy.Symbol, pytree.KeyPath] = dataclasses.field(
- default_factory=dict
- )
- mutation_outputs: list[MutationOutput] = dataclasses.field(default_factory=list)
- def __init__(
- self,
- name: Optional[str],
- layout: OutputSpec,
- inputs: Sequence[Union[IRNode, Sequence[IRNode]]],
- constant_args: Sequence[Any] = (),
- kwargs: Optional[dict[str, Any]] = None,
- output_view: Optional[ReinterpretView] = None,
- python_kernel_name: Optional[str] = None,
- cpp_kernel_name: Optional[str] = None,
- ordered_kwargs_for_cpp_kernel: Iterable[str] = (),
- op_overload: Optional[_OpOverloads] = None,
- ) -> None:
- super().__init__(
- name=name,
- layout=layout,
- inputs=inputs,
- )
- self.constant_args = constant_args
- self.kwargs = kwargs if kwargs else {}
- self.output_view = output_view
- self.op_overload = op_overload
- self.set_cpp_kernel_name(cpp_kernel_name)
- self.set_python_kernel_name(python_kernel_name)
- self.ordered_kwargs_for_cpp_kernel = ordered_kwargs_for_cpp_kernel
- self.collect_arg_kwarg_properties()
- self.unbacked_bindings = {}
- self.mutation_outputs = []
- self.fx_node = V.graph.current_node
- def get_outputs(self) -> list[Buffer]:
- return [self, *self.mutation_outputs]
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def collect_arg_kwarg_properties(self) -> None:
- # if self.op_overload is torch._ops.OpOverload, we can use its schema to collect additional
- # information for args and kwargs, e.g. type and default value, to help with the cpp wrapper codegen
- self.arg_properties = (
- [
- {
- "name": x.name,
- "type": x.real_type,
- "default_value": x.default_value,
- }
- for x in self.op_overload._schema.arguments
- if not x.kwarg_only
- ]
- if isinstance(self.op_overload, torch._ops.OpOverload)
- else [{} for i in range(len(self.inputs))]
- )
- self.allarg_properties = (
- {
- x.name: {"type": x.real_type, "default_value": x.default_value}
- for x in self.op_overload._schema.arguments
- }
- if isinstance(self.op_overload, torch._ops.OpOverload)
- else {}
- )
- # FIXME: self.kwargs does not always match kwargs defined in schema, so sometimes
- # ordered_kwargs_for_cpp_kernel is explicitly passed in.
- if isinstance(self.op_overload, torch._ops.OpOverload):
- if not self.ordered_kwargs_for_cpp_kernel:
- self.ordered_kwargs_for_cpp_kernel = [
- x.name for x in self.op_overload._schema.arguments if x.kwarg_only
- ]
- self.schema_kwargs = [
- x for x in self.op_overload._schema.arguments if x.kwarg_only
- ]
- else:
- self.schema_kwargs = []
- def decide_layout(self) -> None:
- if isinstance(self.layout, FlexibleLayout):
- self.apply_constraint()
- self.freeze_layout()
- def codegen_comment(self, wrapper: PythonWrapperCodegen) -> None:
- origin_str, _detailed_origin_str = get_kernel_metadata(self, wrapper)
- if origin_str:
- wrapper.make_comment(origin_str)
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- raise NotImplementedError
- def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
- self.cpp_kernel_name = cpp_kernel_name
- if not V.graph.cpp_wrapper or not isinstance(
- self.op_overload, torch._ops.OpOverload
- ):
- return
- kernel = self.op_overload
- if self.cpp_kernel_name is None:
- # Try to construct cpp_kernel_name from op_overload
- if kernel.namespace == "aten":
- # Calling with the default kernel name can lead to ambiguous behavior like the following example.
- # repeat_interleave(const at::Tensor & repeats, std::optional<int64_t> output_size=std::nullopt)
- # repeat_interleave(const at::Tensor & self, int64_t repeats,
- # std::optional<int64_t> dim=std::nullopt, std::optional<int64_t> output_size=std::nullopt)
- opname = (
- kernel.__name__.split(".")[0]
- if kernel._overloadname == "default"
- else kernel.__name__.replace(".", "_")
- )
- self.cpp_kernel_name = f"at::_ops::{opname}::call"
- else:
- self.cpp_kernel_name = kernel._schema.name
- def set_python_kernel_name(self, python_kernel_name: Optional[str]) -> None:
- self.python_kernel_name = python_kernel_name
- if python_kernel_name is not None:
- return
- kernel = self.op_overload
- if kernel is None:
- pass
- elif isinstance(kernel, torch._ops.HigherOrderOperator):
- self.python_kernel_name = f"torch.ops.higher_order.{kernel.__name__}"
- else:
- self.python_kernel_name = (
- f"{kernel.__module__.replace('._ops.', '.ops.')}.{kernel.__name__}"
- )
- def get_kernel_name(self) -> str:
- from .codegen.cpp_wrapper_cpu import CppWrapperCpu
- device = d.type if (d := self.get_device()) else V.graph.device_type
- if V.graph.fx_wrapper:
- assert self.python_kernel_name is not None
- return self.python_kernel_name
- elif V.graph.cpp_wrapper:
- assert isinstance(V.graph.wrapper_code, CppWrapperCpu), type(
- V.graph.wrapper_code
- )
- assert self.cpp_kernel_name is not None
- return V.graph.wrapper_code.get_c_shim_func_name(
- self.cpp_kernel_name, device
- )
- else:
- assert self.python_kernel_name is not None
- return self.python_kernel_name
- @staticmethod
- def copy_input(x: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]:
- pw = Pointwise.create(
- device=x.get_device(),
- dtype=x.get_dtype(),
- inner_fn=x.make_loader(),
- ranges=x.get_size(),
- origin_node=x.get_origin_node(),
- traceback=x.get_traceback(),
- )
- pw.realize()
- return pw
- @classmethod
- def process_kernel(
- cls, kernel: _OpOverloads, *args: Any, **kwargs: Any
- ) -> tuple[
- Any,
- list[Any],
- list[Any],
- Callable[[Any, Any], Any],
- Optional[dict[sympy.Symbol, pytree.KeyPath]],
- ]:
- binded_args = {"args": args, "kwargs": kwargs}
- args_flat, args_spec = pytree.tree_flatten(binded_args)
- is_arg_tensor = []
- # tensor_args can be either tensor or torchbind objects
- tensor_args = []
- non_tensor_args: list[Any] = []
- for arg in args_flat:
- is_arg_tensor.append(
- isinstance(arg, IRNode) and not isinstance(arg, GeneratorState)
- )
- if is_arg_tensor[-1]:
- tensor_args.append(arg)
- else:
- if isinstance(arg, Expr):
- arg = V.graph.sizevars.shape_env.create_symintnode(arg, hint=None)
- non_tensor_args.append(arg)
- def unflatten_args(
- new_tensor_args: Sequence[_T], new_non_tensor_args: Sequence[_T]
- ) -> tuple[list[_T], dict[str, _T]]:
- result = []
- it_tensors = iter(new_tensor_args)
- it_non_tensors = iter(new_non_tensor_args)
- for is_tensor in is_arg_tensor:
- if is_tensor:
- result.append(next(it_tensors))
- else:
- result.append(next(it_non_tensors))
- r = pytree.tree_unflatten(result, args_spec)
- return r.get("args", []), r.get("kwargs", {})
- tensor_args = [cls.realize_input(x) for x in tensor_args]
- # freeze layout otherwise our output stride calculation might
- # become incorrect
- for x in tensor_args:
- if is_storage_and_layout(x):
- as_storage_and_layout(x, freeze=True)
- # Rerun fake tensor propagation, because Inductor may have changed the
- # strides of inputs and we need to determine accurately what the
- # output stride will be.
- example_args: list[
- Union[
- torch.Tensor, torch._C.ScriptObject, FakeScriptObject, torch.Generator
- ]
- ] = []
- # We need to retain the constant values of fake tensors that we originally
- # propagated the graph with, because for some operators running without a
- # constant would trigger an error / DataDependentException
- for x in tensor_args:
- # if x is a view of a constant, we need to realize the view
- # (we can't pass the constant into the kernel directly)
- if not isinstance(x, BaseView) and x.get_name() in V.graph.constants:
- example_args.append(V.graph.constants[x.get_name()])
- elif (
- not isinstance(x, BaseView)
- and x.get_name() in V.graph.torchbind_constants
- ):
- example_args.append(V.graph.torchbind_constants[x.get_name()])
- elif isinstance(x, TorchBindObject):
- example_args.append(x.get_value())
- elif isinstance(x, torch._inductor.ir.GeneratorState):
- device_index = x.device.index
- assert x.device.type == "cuda" and device_index is not None
- example_args.append(
- torch.cuda.default_generators[device_index].clone_state()
- )
- else:
- example_args.append(ir_node_to_tensor(x, guard_shape=True))
- new_args, new_kwargs = unflatten_args(example_args, non_tensor_args)
- example_output = kernel(*new_args, **new_kwargs)
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None
- if shape_env := V.fake_mode.shape_env:
- node_meta_val = V.current_node.meta.get("val")
- ctx: AbstractContextManager[None] = nullcontext()
- if V.current_node.target == torch._higher_order_ops.effects.with_effects:
- # remove the first effect token in meta["val"] and meta["unbacked_bindings"]
- node_meta_val = node_meta_val[1]
- ctx = _remove_effect_token_unbacked_bindings(V.current_node)
- with ctx:
- rebind_unbacked(shape_env, V.current_node, example_output)
- unbacked_bindings = compute_unbacked_bindings(
- shape_env, example_output, node_meta_val
- )
- example_out_li = (
- [example_output]
- if not isinstance(example_output, (list, tuple))
- else example_output
- )
- for t in example_out_li:
- if isinstance(t, torch.Tensor) and t.is_sparse:
- msg = "sparsity not handled. Please file issue for sparse inference weights."
- if stack_trace := V.graph.current_node.meta.get("stack_trace", None):
- msg = f"{msg} Found from : \n {stack_trace}"
- V.graph.disable_cudagraphs_reason = msg
- return (
- example_output,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- unbacked_bindings,
- )
- @classmethod
- def convert_to_reinterpret_view(cls, x: IRNode) -> ReinterpretView:
- """
- In order to pass this to an extern kernel we need a
- ReinterpretView not a View. This allows us to avoid some
- unneeded copies.
- """
- assert isinstance(x, BaseView), type(x)
- if isinstance(x, ReinterpretView):
- return x
- # NOTE: Don't use extract_read_writes here as it fails when
- # make_loader() inlines the computation
- x_unwrap_view = x.unwrap_view()
- buf = V.graph.get_buffer(x_unwrap_view.get_name())
- assert buf is not None
- x_unwrap_view_fx_node = buf.get_origin_node()
- # Prefer channels last format according to how the format is set from eager.
- if (
- x_unwrap_view_fx_node is not None
- and "val" in x_unwrap_view_fx_node.meta
- and isinstance(x_unwrap_view, (ReinterpretView, Buffer, MutableBox))
- and isinstance(x_unwrap_view.layout, FlexibleLayout)
- and (
- x_unwrap_view_fx_node.meta["val"].is_contiguous(
- memory_format=torch.channels_last
- )
- or x_unwrap_view_fx_node.meta["val"].is_contiguous(
- memory_format=torch.channels_last_3d
- )
- )
- ):
- x_unwrap_view.freeze_layout_with_same_order(
- make_channels_last_strides_for(x_unwrap_view.get_size())
- )
- else:
- x_unwrap_view.freeze_layout()
- index_args, var_ranges = dependencies.index_vars_squeeze(
- x.get_size(), prefix="r"
- )
- range_vars = index_args[0]
- index = x.make_indexer()(range_vars)
- index = V.graph.sizevars.simplify_with_ranges(index, var_ranges)
- strides = V.graph.sizevars.stride_vars(index, range_vars)
- offset = V.graph.sizevars.offset_var(index, range_vars)
- expected = sympy_dot(range_vars, strides) + offset
- if index != expected:
- log.debug(
- "convert_to_reinterpret_view failed: stride=%s offset=%s index=%s",
- strides,
- offset,
- index,
- )
- raise NotImplementedError
- return ReinterpretView(
- data=x.data,
- layout=FixedLayout(
- device=x.get_device_or_error(),
- dtype=x.get_dtype(),
- size=x.get_size(),
- stride=strides,
- offset=offset,
- is_pinned=False,
- ),
- )
- @classmethod
- def realize_input(cls, x: IRNode) -> IRNode:
- if x is None:
- return NoneAsConstantBuffer()
- if isinstance(x, (Expr, sympy.logic.boolalg.Boolean, int)):
- return ShapeAsConstantBuffer(expr=x)
- if isinstance(x, Constant):
- return V.graph.add_tensor_constant(
- torch.tensor(x.value, dtype=x.get_dtype(), device=x.get_device())
- )
- if isinstance(x, ConstantBuffer):
- return x
- if isinstance(x, TensorBox):
- return cls.realize_input(x.data)
- if isinstance(x, ReinterpretView):
- return ReinterpretView(
- data=cls.realize_input(x.data), layout=x.get_layout()
- )
- if isinstance(x, BaseView):
- x.realize()
- if is_storage_and_layout(x.unwrap_view()):
- try:
- return cls.convert_to_reinterpret_view(x)
- except NotImplementedError:
- pass
- if isinstance(x, StorageBox):
- # TODO(jansel): impose layout preference on realized buffer
- x.realize()
- return x
- if isinstance(x, (NonTensorObj, ShapeAsConstantBuffer)):
- return x
- return cls.copy_input(x)
- @classmethod
- def require_stride1(cls, x: IRNode) -> IRNode:
- if is_storage_and_layout(x):
- if len(x.get_stride()) == 0:
- return x
- for stride in x.get_stride():
- if stride == 1:
- return x
- return cls.copy_input(x)
- @classmethod
- def require_strides(
- cls,
- x: IRNode,
- order: Optional[Sequence[int]] = None,
- exact_strides: Optional[Sequence[_IntLike]] = None,
- allow_padding: bool = False,
- ) -> IRNode:
- assert order is not None or exact_strides is not None
- # Layout generally doesn't matter, but some consuming external ops might have requirements
- if x.get_numel() in (0, 1) and not exact_strides:
- return x
- # require x to have the layout
- if is_storage_and_layout(x):
- if isinstance(x.get_layout(), FlexibleLayout):
- if order:
- # If the the FlexibleLayout already has the size and stride in the required order,
- # freeze it to a FixedLayout by using its current size and stride.
- # The behavior of using its current size and stride or the given order can be different
- # if the size and stride has ambiguilty, for example for a 4D input where the iC = 1:
- # size=[s0, 1, 28, 28], stride=[784, 784, 28, 1]. If the required order is [3, 0, 2, 1] (channels last),
- # the current size and stride already satisfies this order.
- # However by freezing it to the required order, the layout will be changed to:
- # size=[s0, 1, 28, 28], stride=[784, 1, 28, 1]), which is not actually necessary.
- use_current_stride_order = is_stride_order_storage_and_layout(
- x, order
- ) and not free_unbacked_symbols(x.get_layout().stride)
- # fix flexiblelayout to be FixedLayout with stride_order
- as_storage_and_layout(
- x,
- freeze=True,
- want_contiguous=False,
- stride_order=(
- get_stride_order(
- V.graph.sizevars.size_hints_or_throw(
- x.get_layout().stride
- )
- )
- if use_current_stride_order
- else order
- ),
- allow_padding=allow_padding,
- )
- return x
- else:
- # If the exact_strides is given, freeze the FlexibleLayout to a FixedLayout with the exact_strides.
- as_storage_and_layout(
- x,
- freeze=True,
- want_contiguous=False,
- stride_order=None,
- allow_padding=allow_padding,
- exact_strides=exact_strides,
- )
- return x
- elif isinstance(x.get_layout(), (FixedLayout, NonOwningLayout)) and (
- (order and x.get_layout().is_stride_ordered(order))
- or (
- exact_strides
- and significant_strides_equal(
- exact_strides, x.get_layout().stride, x.get_size()
- )
- )
- ):
- return (
- try_match_insignificant_strides(x, exact_strides)
- if exact_strides is not None
- else x
- )
- elif isinstance(
- (mutation_layout := x.get_layout()), MutationLayoutSHOULDREMOVE
- ):
- if isinstance(
- (real_layout := mutation_layout.real_layout()), FlexibleLayout
- ):
- raise AssertionError(
- "the MutationLayoutSHOULDREMOVE's real layout shouldn't be FlexibleLayout"
- )
- elif isinstance(real_layout, FixedLayout) and (
- (order and real_layout.is_stride_ordered(order))
- or (
- exact_strides
- and significant_strides_equal(
- exact_strides, real_layout.stride, x.get_size()
- )
- )
- ):
- return x
- # TODO - Storage to InputBuffer
- if isinstance(x, InputBuffer) and (
- (order and x.get_layout().is_stride_ordered(order))
- or (
- exact_strides
- and significant_strides_equal(
- exact_strides, x.get_layout().stride, x.get_size()
- )
- )
- ):
- return x
- if (
- isinstance(x, TensorBox)
- and isinstance(x.data, BaseView)
- and not isinstance(x.data, ReinterpretView)
- and is_storage_and_layout(unwrap_view := x.unwrap_view())
- and hasattr(unwrap_view, "data")
- and not isinstance(unwrap_view.data, ExternKernelAlloc)
- ):
- try:
- x.data = cls.convert_to_reinterpret_view(x.data)
- if order:
- return cls.require_stride_order(
- x, order, allow_padding=allow_padding
- )
- elif exact_strides:
- return cls.require_exact_strides(
- x, exact_strides, allow_padding=allow_padding
- )
- except NotImplementedError:
- pass
- # Preserve ExpandView representation that would be lost during copy_input
- # Without representation of the expand in inductor IR, in codegen we end up
- # launching a grid for the full size tensor and doing redundant computation
- # across expanded dims.
- # TODO: could also be good to have a codegen fix to recognize overlapping elements
- expanded_dims: Optional[list[int]] = None
- orig_size = x.get_size()
- if exact_strides is not None:
- sizevars = V.graph.sizevars
- expanded_dims = [
- i
- for i in range(len(x.get_size()))
- if sizevars.statically_known_equals(exact_strides[i], 0)
- and sizevars.statically_known_geq(x.get_size()[i], 2)
- ]
- for dim in expanded_dims:
- x = torch._inductor.lowering.slice_(x, dim, 0, 1)
- # Although this is a clone, inductor is good about fusing clones into previous
- # operations if they weren't realized and their layouts were flexible.
- x = cls.copy_input(x)
- as_storage_and_layout(
- x,
- freeze=True,
- want_contiguous=False,
- stride_order=order,
- allow_padding=allow_padding,
- exact_strides=exact_strides,
- )
- if order:
- assert is_stride_order_storage_and_layout(x, order)
- elif expanded_dims:
- assert orig_size is not None and exact_strides is not None
- x = torch._inductor.lowering.expand(x, orig_size)
- # the expand will sometimes may change insignificant strides, so match them back
- return try_match_insignificant_strides(x, exact_strides)
- return x
- @classmethod
- def require_exact_strides(
- cls, x: IRNode, exact_strides: Sequence[_IntLike], allow_padding: bool = False
- ) -> IRNode:
- return cls.require_strides(
- x, exact_strides=exact_strides, allow_padding=allow_padding
- )
- @classmethod
- def require_stride_order(
- cls, x: IRNode, order: Sequence[int], allow_padding: bool = False
- ) -> IRNode:
- return cls.require_strides(x, order=order, allow_padding=allow_padding)
- @classmethod
- def require_channels_last(cls, x: IRNode) -> IRNode:
- return cls.require_stride_order(x, NHWC_STRIDE_ORDER)
- @classmethod
- def require_channels_last_3d(cls, x: IRNode) -> IRNode:
- return cls.require_stride_order(x, NHWDC_STRIDE_ORDER)
- @classmethod
- def require_contiguous(cls, x: IRNode) -> IRNode:
- def is_mkldnn_tensor(x: IRNode) -> bool:
- try:
- name = x.get_name()
- except (AttributeError, NotImplementedError):
- return False
- return name in V.graph.constants and V.graph.constants[name].is_mkldnn
- # TODO move this to the more proper places
- if is_mkldnn_tensor(x):
- return x
- else:
- return cls.require_exact_strides(
- x, FlexibleLayout.contiguous_strides(x.get_size())
- )
- @classmethod
- def require_contiguous_strides(cls, x: IRNode) -> IRNode:
- # TODO: combine this with require_contiguous after
- # https://github.com/pytorch/pytorch/pull/148235 lands.
- return cls.require_exact_strides(
- x, FlexibleLayout.contiguous_strides(x.get_size())
- )
- def apply_constraint(self) -> None:
- pass
- def fill_non_provided_args(
- self, args: Sequence[Any], kwargs: dict[str, Any]
- ) -> Sequence[Any]:
- # Previously, we want to maintain forward-compatibility by skipping
- # default args in the serialized artifacts in fbcode. However,
- # some of our shim interfaces require default values being OrderedSet.
- # Discussed with Sherlock offline and we decided to allow serializing
- # default args into the C++ wrapper code for now. We will refine this
- # part if we see real FC requirement. More details related to FC
- # can be found at:
- # https://docs.google.com/document/d/1FzWm-sHYwmRi3x_g036kOxd99KaYquUsA-L5JwOn8ys/edit?usp=sharing
- assert isinstance(args, Sequence), type(args)
- if not isinstance(args, list):
- args = list(args)
- assert self.arg_properties, "ExternKernel.arg_properties should not be empty"
- n_args = len(args)
- n_pos_args = len(self.arg_properties)
- # For cpp wrapper, if some positional args are not provided, we need to check
- # if they're in the kwargs or use their default value
- if n_args < n_pos_args:
- log.debug(
- "%s has %d unprovided positional arguments. "
- "Will check if they are in the keyword arguments or will use default values.",
- self.op_overload,
- n_pos_args - n_args,
- )
- for i in range(n_args, n_pos_args):
- arg_name = self.arg_properties[i]["name"]
- args.append(
- kwargs[arg_name]
- if arg_name in kwargs
- else self.arg_properties[i]["default_value"]
- )
- return args
- def codegen_const_args(self, names: Optional[list[str]] = None) -> list[str]:
- if V.graph.cpp_wrapper:
- result = []
- # Aten ops follow the convention that tensor args are before non-tensor args,
- # in which case the following 'len(self.inputs) + i' logic works. But this
- # may not be true for other ops, and if that is the case, caller needs to
- # pass in a list of const arg names for arg_properties lookup.
- name_to_arg_properties = None
- if names and self.arg_properties:
- assert len(self.constant_args) == len(names), (
- "names passed to codegen_const_args does not match self.constant_args"
- )
- name_to_arg_properties = {
- arg.get("name"): arg for arg in self.arg_properties
- }
- for i, x in enumerate(self.constant_args):
- if name_to_arg_properties is not None:
- assert names is not None
- prop = name_to_arg_properties.get(names[i])
- type_ = prop.get("type") if prop else None
- else:
- idx = len(self.inputs) + i
- type_ = (
- self.arg_properties[idx].get("type")
- if self.arg_properties and idx < len(self.arg_properties)
- else None
- )
- result.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
- return result
- else:
- return [V.graph.wrapper_code.val_to_arg_str(a) for a in self.constant_args]
- def codegen_args(self) -> list[str]:
- if V.graph.cpp_wrapper and self.op_overload is not None:
- # cpp wrapper needs special logic to fill in missing args with default values
- inputs = self.fill_non_provided_args(
- [*self.inputs, *self.constant_args], self.kwargs
- )
- # fill_non_provided_args has handled constant args, so no need to codegen for that later
- need_codegen_constant_args = False
- else:
- inputs = self.inputs
- need_codegen_constant_args = True
- args = []
- for i, x in enumerate(inputs):
- if V.graph.cpp_wrapper:
- assert self.arg_properties and i < len(self.arg_properties), (
- "Invalid access to ExternKernel.arg_properties"
- )
- type_ = self.arg_properties[i].get("type")
- args.append(V.graph.wrapper_code.val_to_arg_str(x, type_))
- else:
- args.append(V.graph.wrapper_code.val_to_arg_str(x))
- if need_codegen_constant_args:
- args.extend(self.codegen_const_args())
- return args
- def get_kwargs_value(self, arg_name: str, **kwargs: Any) -> Any:
- """Given an argument name, queries for values in (in order):
- 1. any provided kwargs for this function.
- 2. the class self.kwargs member.
- 3. any available default arguments in self.allarg_properties."""
- if arg_name in kwargs:
- return kwargs.get(arg_name)
- if arg_name in self.kwargs:
- return self.kwargs.get(arg_name)
- if (arg := self.allarg_properties.get(arg_name)) is not None:
- return arg.get("default_value")
- raise AssertionError(f"{arg_name} not in self.allarg_properties")
- def codegen_kwargs(self, skip_out: bool = False) -> list[str]:
- if V.graph.cpp_wrapper:
- if self.op_overload is not None and len(self.schema_kwargs) == 0:
- # All the args should have been generated by fill_non_provided_args in codegen_args
- return []
- kwargs = []
- for arg_name in self.ordered_kwargs_for_cpp_kernel:
- if skip_out and arg_name == "out":
- # ExternKernelOut has its own logic for inserting the out parameter
- continue
- v = self.get_kwargs_value(arg_name)
- if isinstance(v, Expr):
- kwargs.append(v)
- else:
- assert self.allarg_properties is not None
- type_ = self.allarg_properties.get(arg_name, {}).get("type")
- kwargs.append(V.graph.wrapper_code.val_to_arg_str(v, type_))
- else:
- kwargs = [
- f"{k}={V.graph.wrapper_code.val_to_arg_str(v)}"
- for k, v in self.kwargs.items()
- ]
- return kwargs
- def get_op_name(self) -> str:
- if self.fx_node is not None:
- target = self.fx_node.target
- op_namespace = getattr(target, "__module__", "unknown_namespace")
- op_namespace = op_namespace.replace("._ops.", ".ops.")
- op_namespace = op_namespace.rsplit(".", 1)[0]
- op_name = f"{op_namespace}.{target}"
- else:
- op_name = "unknown_op"
- return op_name
- def codegen_size_asserts(self, wrapper: PythonWrapperCodegen) -> None:
- if config.size_asserts and not V.graph.cpp_wrapper:
- # comparing strides for 0 size tensor is tricky. Ignore them for now.
- if sympy_product(self.get_size()) == 0:
- return
- size = V.graph.wrapper_code.codegen_shape_tuple(self.get_size())
- stride = V.graph.wrapper_code.codegen_shape_tuple(self.get_stride())
- op_name = self.get_op_name()
- wrapper.writeline(
- f"assert_size_stride({self.get_name()}, {size}, {stride}, {op_name!r})"
- )
- def codegen_alignment_asserts(self, wrapper: PythonWrapperCodegen) -> None:
- if config.alignment_asserts and not V.graph.cpp_wrapper:
- name = self.get_name()
- aligned = name not in V.graph.unaligned_buffers
- op_name = self.get_op_name()
- if aligned:
- wrapper.writeline(
- f"assert_alignment({name}, {GPU_ALIGN_BYTES}, {op_name!r})"
- )
- else:
- wrapper.writeline(
- f"# buffer {name} (op: {op_name}) is assumed to be not aligned"
- )
- def codegen_memory_tracking(self, wrapper: PythonWrapperCodegen) -> None:
- """
- Track outputs of fallback operators if config.test_configs.track_memory_lifecycle
- """
- if not config.test_configs.track_memory_lifecycle or V.graph.cpp_wrapper:
- return
- wrapper.write_memory_track_allocation_once()
- name = self.get_name()
- wrapper.writeline(f"track_tensor({name}, '{name}')")
- def get_group_stride(self) -> tuple[list[Sequence[Expr]], list[Expr]]:
- """
- get output sizes and strides, for template_codegen
- """
- _size = self.get_size()
- _stride = self.get_stride()
- # iter_ranges = _size of output tensor, reduce_range = [] because no reduction
- return [_size, []], _stride
- def canonicalize(self) -> tuple[Expr, Sequence[Expr]]:
- """
- Manually get canonicalization of the output index
- """
- # manually generate index formula for conv
- sizevars = V.graph.sizevars
- sizes = self.get_size()
- strides = self.get_stride()
- strides = [sizevars.size_hint(x) for x in strides]
- # TODO: I can't tell if the symbols here are temporary
- index_vars = [sympy_index_symbol(f"d{i}") for i in range(len(sizes))]
- # reorder index vars according to stride
- index_order = sorted(range(len(strides)), key=strides.__getitem__, reverse=True)
- lookup = {pos: idx for idx, pos in enumerate(index_order)}
- order = [lookup[i] for i in range(len(lookup))]
- index_vars = [index_vars[i] for i in order]
- indexer = self.make_indexer()
- index = indexer(index_vars)
- new_sizes, reindex, _prune = V.graph.sizevars._simplify_loops(
- index_vars, sizes, [index]
- )
- # assign new variables each dimension to deal with numbering mismatches
- # d0, d1, d2 could become d0, d2 -- which won't match d0, d1
- _, add_var = var_builder("c")
- replacement = dict(zip(index_vars, reindex([add_var(x) for x in new_sizes])))
- index = sympy_subs(sympy.expand(index), replacement)
- return index, tuple(new_sizes)
- @cache_on_self_and_args("ExternKernel")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- # NB: It's not necessary to check regular inputs as we automatically
- # have dependencies on them
- maybe_get_symbols = (
- maybe_free_unbacked_symbols if unbacked_only else maybe_free_symbols
- )
- r = InputsKernel.get_free_symbol_uses(self, unbacked_only)
- for arg in self.constant_args:
- r |= maybe_get_symbols(arg)
- for arg in self.kwargs.values():
- r |= maybe_get_symbols(arg)
- return r
- def __str__(self) -> str:
- kernel_name = getattr(self, "python_kernel_name", None)
- lines = [
- f"python_kernel_name={kernel_name!r}",
- ]
- lines += [
- f"{field.name}={getattr(self, field.name)}"
- for field in dataclasses.fields(self)
- ]
- lines.append(f"origin_node={self.origin_node!r}")
- return self.str_helper(lines)
- __repr__ = __str__
- @ir_dataclass(frozen=False)
- class ExternKernelOut(ExternKernel):
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.generate_extern_kernel_out(self)
- def __init__(
- self,
- layout: Layout,
- inputs: Sequence[IRNode],
- constant_args: Sequence[Any] = (),
- kwargs: Optional[dict[str, Any]] = None,
- output_view: Optional[ReinterpretView] = None,
- python_kernel_name: Optional[str] = None,
- cpp_kernel_name: Optional[str] = None,
- ordered_kwargs_for_cpp_kernel: Sequence[Any] = (),
- op_overload: Optional[_OpOverloads] = None,
- ) -> None:
- unwrapped_inputs = self.unwrap_storage(inputs)
- assert isinstance(unwrapped_inputs, Sequence), type(unwrapped_inputs)
- super().__init__(
- None,
- layout,
- unwrapped_inputs,
- constant_args,
- kwargs or {},
- None,
- python_kernel_name,
- cpp_kernel_name,
- ordered_kwargs_for_cpp_kernel,
- op_overload,
- )
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- def should_allocate(self) -> bool:
- return True
- class RandomSeeds(ExternKernelOut):
- def __init__(self, count: int, device: torch.device) -> None:
- limits = torch.iinfo(torch.int64)
- super().__init__(
- layout=FixedLayout(
- device=device,
- dtype=torch.int64,
- size=[count],
- ),
- inputs=[],
- constant_args=[limits.min, limits.max, [count]],
- python_kernel_name="aten.randint.low_out",
- # FIXME: Ideally we should only use at::_ops::randint_low_out::call here,
- # but the signature is different from is at::randint_out. Again,
- # we can simplify the code when only keeping an ABI-compatible version.
- cpp_kernel_name="at::_ops::randint_low_out::call",
- op_overload=aten.randint.low_out,
- )
- class ExternKernelAlloc(ExternKernel):
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.generate_extern_kernel_alloc(self)
- def __init__(
- self,
- layout: OutputSpec,
- inputs: Sequence[IRNode],
- constant_args: Sequence[Any] = (),
- kwargs: Optional[dict[str, Any]] = None,
- python_kernel_name: Optional[str] = None,
- cpp_kernel_name: Optional[str] = None,
- ordered_kwargs_for_cpp_kernel: Sequence[Any] = (),
- op_overload: Optional[_OpOverloads] = None,
- ) -> None:
- unwrapped_inputs = self.unwrap_storage(inputs)
- assert all(isinstance(i, IRNode) for i in unwrapped_inputs)
- super().__init__(
- None,
- layout,
- cast(Sequence[IRNode], unwrapped_inputs),
- constant_args,
- kwargs or {},
- None,
- python_kernel_name,
- cpp_kernel_name,
- ordered_kwargs_for_cpp_kernel,
- op_overload,
- )
- # We need output buffers for generating kernel arguments in the
- # abi-compatible mode, where we retrieve outputs by pass each individual
- # output through the abi-compatible interface.
- self.outputs: Sequence[Any] = []
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- def should_allocate(self) -> bool:
- return False
- def apply_constraint(self) -> None:
- raise NotImplementedError
- class MutationOutput(Buffer):
- """
- An output buffer that represents the mutation of a pre-existing buffer
- """
- def __init__(
- self, layout: OutputSpec, mutated_node: IRNode, mutating_node: Operation
- ) -> None:
- super().__init__(name=None, layout=layout)
- mutated_node_name = mutated_node.get_name()
- V.graph.mark_buffer_mutated(mutated_node_name)
- self.mutation_names = [mutated_node_name]
- self.mutating_node: Operation = mutating_node
- self.name = V.graph.register_buffer(self)
- def get_defining_op(self) -> Operation:
- return self.mutating_node
- def get_mutation_names(self) -> Sequence[str]:
- return self.mutation_names
- def should_allocate(self) -> bool:
- return False
- def get_mutation_buffers(self) -> Sequence[IRNode]:
- mutation_names = self.get_mutation_names()
- return [
- buf
- for buf in (V.graph.try_get_buffer(name) for name in mutation_names)
- if buf is not None
- ]
- class TMADescriptor(ExternKernel):
- """
- An IR node representing a generic host-side TMA descriptor in the Triton API
- Mostly useful for user-defined Triton kernels relying on host-side TMA;
- but can, in principle, be used for Inductor's Triton templates, too.
- See TMADescriptorExperimental and TMADescriptorStable for the two implementations
- (the old API and the new API)
- """
- # as TMA descriptors are immutable,
- # we can dedup them by the input args
- _CACHE: dict[Any, TMADescriptor] = {}
- @classmethod
- def _create_impl(
- cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]]
- ) -> TMADescriptor:
- assert len(tma_meta) == 2
- if tma_meta[0] == "experimental":
- return TMADescriptorExperimental(tensor, *tma_meta[1])
- else:
- assert tma_meta[0] == "stable"
- return TMADescriptorStable(tensor, *tma_meta[1])
- @classmethod
- def create(
- cls, tensor: IRNode, tma_meta: tuple[str, tuple[Any, ...]]
- ) -> TMADescriptor:
- key = (id(tensor), tma_meta)
- if key not in cls._CACHE:
- cls._CACHE[key] = cls._create_impl(tensor, tma_meta)
- return cls._CACHE[key]
- def __init__(
- self, tensor: IRNode, inputs: Sequence[Any], constant_args: Sequence[Any]
- ) -> None:
- super().__init__(
- None,
- # link back to the underlying tensor in terms of ownership
- # to avoid getting the underlying tensor deleted *before*
- # the TMADescriptor node can be deleted.
- NonOwningLayout(
- ReinterpretView(
- data=tensor,
- layout=tensor.get_layout(),
- )
- ),
- cast(Sequence[Buffer], inputs),
- tuple(constant_args),
- None,
- )
- self.tensor = tensor
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.generate_tma_descriptor(self)
- def get_tensor(self) -> IRNode:
- return self.tensor
- class TMADescriptorExperimental(TMADescriptor):
- """
- the new host-side TMA Descriptor API:
- (the ones obtained via create_{1d,2d}_tma_descriptor calls).
- See also TMADescriptorStable for the new API.
- """
- def __init__(
- self,
- tensor: IRNode,
- dims: list[Union[int, torch.SymInt]],
- block_dims: list[Union[int, torch.SymInt]],
- element_size: Optional[int] = None,
- ) -> None:
- assert len(dims) in (1, 2)
- assert len(dims) == len(block_dims)
- if element_size is None:
- element_size = tensor.get_dtype().itemsize
- self.dims = dims
- self.block_dims = block_dims
- self.element_size = element_size
- self.rank = len(self.dims)
- inputs = [tensor]
- constant_args = [
- *self.dims,
- *self.block_dims,
- self.element_size,
- ]
- super().__init__(
- tensor=tensor,
- inputs=inputs,
- constant_args=constant_args,
- )
- class TMADescriptorStable(TMADescriptor):
- """
- the new host-side TMA descriptor API
- (the ones obtained via TensorDescriptor.from_tensor).
- See also TMADescriptorExperimental for the old API.
- """
- def __init__(self, tensor: IRNode, block_shape: list[Union[int, torch.SymInt]]):
- self.block_shape = block_shape
- super().__init__(
- tensor=tensor,
- inputs=[tensor],
- constant_args=block_shape,
- )
- class SubgraphBuffer(ExternKernel):
- def __init__(
- self,
- layout: Layout,
- input_nodes: list[Buffer],
- gm: torch.fx.GraphModule,
- example_inputs: list[Any],
- subgraph_name: str,
- ):
- super().__init__(None, layout, input_nodes)
- self.gm = gm
- self.example_inputs = example_inputs
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- self.subgraph = V.graph.make_subgraph(self.gm, example_inputs, subgraph_name)
- assert is_node_sequence(self.inputs)
- sym_inputs = get_symbolic_inputs(self.inputs)
- for sym_inp in sym_inputs:
- self.subgraph.graph_inputs[sym_inp.name] = sym_inp
- self.subgraph.graph_input_names.append(sym_inp.name)
- self.sym_inputs = [sym_var.name for sym_var in sym_inputs]
- import torch._inductor.config as inductor_config
- with V.set_graph_handler(self.subgraph):
- # Don't bother autotuning on Triton here
- with inductor_config.patch(
- max_autotune=False,
- max_autotune_gemm=False,
- max_autotune_gemm_backends="ATEN",
- ):
- self.subgraph.run(*self.example_inputs)
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- class CodegenGraph:
- def __init__(self, graph: GraphLowering):
- self.graph = graph
- self.name = graph.name
- assert is_node_sequence(self.inputs)
- outer_inputs = [t.codegen_reference() for t in self.inputs]
- wrapper.codegen_subgraph_with_flattened_outputs(
- CodegenGraph(self.subgraph),
- [*self.sym_inputs, *outer_inputs],
- [self.name],
- )
- class UserDefinedTritonKernel(ExternKernel):
- def get_kernel_and_metadata(self) -> tuple[Kernel, Any, list[str], list[str]]:
- from triton.runtime.autotuner import Autotuner
- from torch._higher_order_ops.triton_kernel_wrap import kernel_side_table
- kernel = kernel_side_table.get_kernel(self.kernel_idx)
- configs = []
- restore_value_args: list[str] = []
- reset_to_zero_args: list[str] = []
- if isinstance(kernel, Autotuner):
- # https://github.com/triton-lang/triton/pull/5083
- # changes kernel.restore_idx to kernel.restore_value
- if hasattr(kernel, "restore_idx"):
- restore_value_args.extend(
- kernel.fn.arg_names[i] for i in kernel.restore_idx
- )
- else:
- assert hasattr(kernel, "restore_value")
- restore_value_args.extend(kernel.restore_value)
- if hasattr(kernel, "reset_idx"):
- for i in kernel.reset_idx:
- reset_to_zero_args.append(kernel.fn.arg_names[i])
- else:
- assert hasattr(kernel, "reset_to_zero")
- reset_to_zero_args.extend(kernel.reset_to_zero)
- configs = kernel.configs
- kernel = kernel.fn
- return kernel, configs, restore_value_args, reset_to_zero_args
- @override
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- """Overrides the parent member.
- See https://github.com/pytorch/pytorch/issues/151692"""
- from torch._inductor.utils import triton_version_uses_attrs_dict
- (
- kernel,
- configs,
- restore_value_args,
- reset_to_zero_args,
- ) = self.get_kernel_and_metadata()
- # Definition of kernel
- (
- new_name,
- triton_meta,
- extra_launch_args,
- ) = wrapper.define_user_defined_triton_kernel(
- kernel,
- configs,
- self.kwargs,
- restore_value_args,
- reset_to_zero_args,
- self.grid,
- )
- named_args = {
- k: self.get_kwargs_value(k) for k in self.ordered_kwargs_for_cpp_kernel
- }
- assert hasattr(kernel, "arg_names") and hasattr(kernel, "constexprs"), type(
- kernel
- )
- constexpr_names = OrderedSet(kernel.arg_names[i] for i in kernel.constexprs)
- args: list[Any] = []
- arg_types: list[Any] = []
- raw_keys_filtered: list[Any] = []
- raw_args_filtered: list[Any] = []
- for name, arg in itertools.chain(
- named_args.items(), zip(itertools.repeat(""), extra_launch_args)
- ):
- if name in constexpr_names and triton_version_uses_attrs_dict():
- # see #160000 - we don't pass in constexpr args to speed up runtime.
- continue
- raw_keys_filtered.append(name)
- raw_args_filtered.append(arg)
- if isinstance(arg, IRNode):
- args.append(arg.codegen_reference())
- arg_types.append(arg.get_dtype())
- elif isinstance(arg, (int, float, bool, sympy.Expr)):
- args.append(arg)
- arg_types.append(type(arg))
- elif name in constexpr_names:
- # insert a dummy value for constexpr args of unsupported type
- # constexprs will end up getting baked into the kernel at compile time
- args.append(-1)
- arg_types.append(int)
- elif arg is None:
- """
- Filter out None args.
- see https://github.com/pytorch/pytorch/issues/115344
- Two cases for a None arg:
- 1. The arg is already tl.constexpr, so leave it in
- 2. The arg is not tl.constexpr so we have to remove it
- """
- if triton_version_uses_attrs_dict():
- args.append(-1)
- arg_types.append(int)
- else:
- raw_keys_filtered.pop()
- raw_args_filtered.pop()
- else:
- raise NotImplementedError(f"Unsupported arg type: {type(arg)}: {arg}")
- self.codegen_comment(wrapper)
- wrapper.generate_kernel_call(
- new_name,
- args,
- arg_types=arg_types,
- raw_args=raw_args_filtered,
- raw_keys=raw_keys_filtered,
- triton_meta=triton_meta,
- triton=True,
- device=self.get_device(),
- original_fxnode_name=self.fx_node.name,
- )
- @cache_on_self_and_args("UserDefinedTritonKernel")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- # add unbacked symbols used in the grid to the ones used
- # in the kwargs (the latter is generated by ExternKernel)
- return super().get_free_symbol_uses(unbacked_only) | get_free_symbols(
- self.grid, unbacked_only
- )
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def __init__(
- self,
- *,
- kernel_idx: int,
- grid: Any,
- tma_descriptor_metadata: dict[str, Any],
- kernel_args: dict[str, Any],
- ) -> None:
- inputs: list[IRNode] = []
- kwargs: dict[str, IRNode] = {}
- constant_args: list[IRNode] = []
- for k, v in kernel_args.items():
- if isinstance(v, TensorBox):
- t = InputsKernel.unwrap_storage_for_input(self.realize_input(v))
- if k in tma_descriptor_metadata:
- t = TMADescriptor.create(t, tma_descriptor_metadata[k])
- inputs.append(t)
- kwargs[k] = t
- else:
- constant_args.append(v)
- kwargs[k] = v
- assert len(inputs) != 0
- self.device = inputs[0].get_device()
- assert isinstance(inputs, Sequence), type(inputs)
- super().__init__(
- None,
- NoneLayout(device=self.device),
- inputs,
- tuple(constant_args),
- kwargs,
- )
- self.kernel_idx = kernel_idx
- self.grid = grid
- kernel, configs, _, _ = self.get_kernel_and_metadata()
- # If we are autotuning, not all arguments will be passed
- assert hasattr(kernel, "arg_names")
- self.ordered_kwargs_for_cpp_kernel = [
- arg for arg in kernel.arg_names if arg in kernel_args
- ]
- from torch._higher_order_ops.triton_kernel_wrap import identify_mutated_tensors
- autotuned_kwargs = configs[0].kwargs if len(configs) > 0 else {}
- self.mutable_args = [
- kernel_args[key]
- for key in identify_mutated_tensors(
- kernel, {**kernel_args, **autotuned_kwargs}, tma_descriptor_metadata
- )
- ]
- self.mutation_outputs = [
- MutationOutput(NoneLayout(device=self.device), buf, self)
- for buf in self.mutable_args
- ]
- V.graph.register_operation(self)
- def get_outputs(self) -> list[Buffer]:
- return list(self.mutation_outputs)
- def get_device(self) -> Optional[torch.device]:
- return self.device
- class InplaceBernoulliFallback(ExternKernel):
- """
- This needs to be a custom class to handle mutation properly
- """
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- assert all(isinstance(t, IRNode) for t in self.inputs)
- (x,) = (cast(IRNode, t).codegen_reference() for t in self.inputs)
- if V.graph.cpp_wrapper:
- # Inductor doesn't really support aten Generator, so the Generator kwarg is always NULL here,
- # which needs to be explicitly generated for cpp wrapper
- wrapper.writeline(
- f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}, NULL){wrapper.ending}"
- )
- else:
- wrapper.writeline(
- f"{self.get_kernel_name()}({x}, {', '.join(map(repr, self.constant_args))}){wrapper.ending}"
- )
- def should_allocate(self) -> bool:
- return False
- def get_mutation_names(self) -> Sequence[str]:
- return [self.input_name(0)]
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def __init__(
- self, op_overload: _OpOverloads, x: IRNode, *constant_args: Any
- ) -> None:
- super().__init__(
- None,
- NoneLayout(device=x.get_device()),
- self.unwrap_storage([x]),
- constant_args,
- op_overload=op_overload,
- )
- V.graph.mark_buffer_mutated(x.get_name())
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- # Used to deal with torch.complex types
- class InplaceCopyFallback(ExternKernel):
- """
- This needs to be a custom class to handle mutation properly
- """
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- (dst, src, non_blocking) = self.codegen_args()
- wrapper.codegen_device_copy(src, dst, non_blocking)
- def should_allocate(self) -> bool:
- return False
- def get_mutation_names(self) -> Sequence[str]:
- return [self.input_name(0)]
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def __init__(
- self,
- layout: OutputSpec,
- inputs: Sequence[IRNode],
- constant_args: Sequence[Any],
- ) -> None:
- super().__init__(
- None,
- layout,
- inputs,
- constant_args,
- python_kernel_name="aten.copy_",
- cpp_kernel_name="aoti_torch_copy_",
- )
- V.graph.mark_buffer_mutated(inputs[0].get_name())
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- @classmethod
- def create(
- cls, dst: IRNode, src: IRNode, non_blocking: bool = False
- ) -> InplaceCopyFallback:
- inputs = [cls.realize_input(t) for t in [dst, src]]
- constant_args = (non_blocking,)
- result = InplaceCopyFallback(
- NoneLayout(device=dst.get_device()),
- inputs,
- constant_args,
- )
- return result
- class MutatingFirstArgExternKernel(ExternKernel):
- """
- This needs to be a custom class to handle mutation properly
- """
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- assert is_node_sequence(self.inputs)
- argrefs = [
- *(t.codegen_reference() for t in self.inputs),
- *map(repr, self.constant_args),
- ]
- wrapper.writeline(
- f"{self.get_kernel_name()}({', '.join(argrefs)}){wrapper.ending}"
- )
- def should_allocate(self) -> bool:
- return False
- def get_mutation_names(self) -> Sequence[str]:
- return [self.input_name(0)]
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def has_side_effects(self) -> bool:
- return True
- class ResizeStorageBytes(MutatingFirstArgExternKernel):
- def __init__(self, variable: IRNode, new_size: int) -> None:
- assert isinstance(new_size, int), "TODO: dynamic shapes"
- super().__init__(
- None,
- NoneLayout(device=variable.get_device()),
- self.unwrap_storage([variable]),
- constant_args=(new_size,),
- )
- V.graph.mark_buffer_mutated(variable.get_name())
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- self.python_kernel_name = "inductor_ops.resize_storage_bytes_"
- self.cpp_kernel_name = "torch::inductor::resize_storage_bytes_"
- assert isinstance(variable, (BaseView, StorageBox, TensorBox)), type(variable)
- V.graph.never_reuse_buffers.add(variable.data.get_name())
- class SetSourceTensorKernel(ExternKernelAlloc):
- def __init__(self, self_tensor: IRNode, storage_tensor: IRNode) -> None:
- storage_tensor.freeze_layout()
- super().__init__(
- storage_tensor.get_layout(),
- [self_tensor, storage_tensor],
- python_kernel_name="torch.ops.aten.set_.source_Tensor",
- op_overload=torch.ops.aten.set_.source_Tensor,
- )
- assert isinstance(self_tensor, (BaseView, StorageBox, TensorBox)), type(
- self_tensor
- )
- V.graph.never_reuse_buffers.add(self_tensor.data.get_name())
- V.graph.never_reuse_buffers.add(storage_tensor.get_name())
- V.graph.never_reuse_buffers.add(self.get_name())
- device = storage_tensor.get_device()
- self.mutation_outputs = [
- MutationOutput(NoneLayout(device=device), self_tensor, self),
- MutationOutput(NoneLayout(device=device), storage_tensor, self),
- ]
- def get_inputs_that_alias_output(self) -> Sequence[str]:
- return [self.input_name(0), self.input_name(1)]
- class ScatterFallback(ExternKernel):
- """
- This needs to be a custom class to handle mutation properly.
- This class handles both aten.scatter_ and aten.scatter_reduce_.
- It also handle the case `src` being a scalar properly.
- """
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- reduce = self.kwargs["reduce"]
- if V.graph.cpp_wrapper:
- # Follow aten/src/ATen/native/ReductionType.h:get_operator_enum
- get_operator_enum = {"add": "sum", "multiply": "prod"}
- if reduce in get_operator_enum:
- reduce = get_operator_enum[reduce]
- assert is_node_sequence(self.inputs)
- if self.src_is_tensor:
- (x, index, src) = (t.codegen_reference() for t in self.inputs)
- else:
- (x, index) = (t.codegen_reference() for t in self.inputs)
- src = self.constant_args[1]
- wrapper.generate_scatter_fallback(
- x,
- [x, self.constant_args[0], index, src],
- self.cpp_kernel_name,
- self.python_kernel_name,
- self.src_is_tensor,
- reduce,
- self.codegen_kwargs(),
- )
- def should_allocate(self) -> bool:
- return False
- def get_mutation_names(self) -> list[str]:
- inp = self.inputs[0]
- assert isinstance(inp, IRNode)
- return [inp.get_name()]
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def __init__(
- self,
- op_overload: _OpOverloads,
- x: IRNode,
- dim: int,
- index: IRNode,
- src: IRNode,
- *,
- reduce: Optional[str] = None,
- include_self: bool = True,
- ) -> None:
- self.src_is_tensor = isinstance(src, TensorBox)
- constant_args: tuple[Any, ...]
- if self.src_is_tensor:
- tensors = [self.realize_input(t) for t in [x, index, src]]
- constant_args = (dim,)
- else:
- tensors = [self.realize_input(t) for t in [x, index]]
- constant_args = (dim, src)
- super().__init__(
- None,
- NoneLayout(device=x.get_device()),
- self.unwrap_storage(tensors),
- constant_args,
- {"reduce": reduce, "include_self": include_self},
- python_kernel_name=str(op_overload),
- ordered_kwargs_for_cpp_kernel=["reduce", "include_self"],
- op_overload=op_overload,
- )
- V.graph.mark_buffer_mutated(x.get_name())
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- class IndexPutFallback(ExternKernel):
- """
- This needs to be a custom class to handle mutation and indices properly
- """
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- assert is_node_sequence(self.inputs)
- (x, values, *valid_indices) = (t.codegen_reference() for t in self.inputs)
- indices = []
- iter_valid_indices = iter(valid_indices)
- for i, _ in enumerate(self.indices):
- if self.indices[i] is not None:
- indices.append(next(iter_valid_indices))
- else:
- indices.append(V.graph.wrapper_code.none_str)
- wrapper.generate_index_put_fallback(
- self.get_kernel_name(), x, indices, values, *self.codegen_const_args()
- )
- def should_allocate(self) -> bool:
- return False
- def get_mutation_names(self) -> Sequence[str]:
- return [self.input_name(0)]
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- def __init__(
- self,
- op_overload: torch._ops.OpOverload,
- x: IRNode,
- indices: list[Any],
- values: Sequence[Any],
- accumulate: Any,
- ) -> None:
- self.indices = indices
- valid_indices = [i for i in indices if i is not None]
- tensors = [self.realize_input(x) for x in [x, values, *valid_indices]]
- cpp_kernel_name = "aoti_torch_index_put_out"
- super().__init__(
- None,
- NoneLayout(device=x.get_device()),
- self.unwrap_storage(tensors),
- (accumulate,),
- python_kernel_name="aten.index_put_",
- cpp_kernel_name=cpp_kernel_name,
- op_overload=op_overload,
- )
- V.graph.mark_buffer_mutated(self.input_name(0))
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- class DeviceCopy(ExternKernelOut):
- @classmethod
- def create(cls, x: IRNode, device: torch.device, non_blocking: bool) -> IRNode:
- if (
- not x.is_extern()
- and all(r in V.graph.constants for r in x.get_read_names())
- and not config.aot_inductor.use_runtime_constant_folding
- ):
- return x.constant_to_device(device)
- V.graph.add_device_info(device)
- x_device = x.get_device()
- assert x_device is not None
- V.graph.add_device_info(x_device)
- developer_warning("DeviceCopy in input program")
- constant_args = (non_blocking,)
- # Device Copy should keep the same layout as input
- x = ExternKernel.require_contiguous(x)
- stride = None
- if x.get_size():
- # x.get_stride() may be unimplemented if x's size is empty
- stride = x.get_stride()
- is_destination_pinned = (
- is_gpu(x_device.type) and device.type == "cpu" and non_blocking
- )
- is_source_pinned = (
- x_device.type == "cpu" and is_gpu(device.type) and non_blocking
- )
- if is_source_pinned and is_storage_and_layout(x):
- x.get_layout().is_pinned = True
- return DeviceCopy(
- FixedLayout(
- device,
- x.get_dtype(),
- x.get_size(),
- stride,
- is_pinned=is_destination_pinned,
- ),
- [cls.realize_input(x)],
- constant_args,
- )
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- args = self.codegen_args()
- assert len(args) == 2
- if self.output_view:
- wrapper.codegen_device_copy(
- args[0], self.output_view.codegen_reference(), args[1]
- )
- else:
- wrapper.codegen_device_copy(args[0], self.codegen_reference(), args[1])
- class DynamicSelectStorageOffset(ExternKernel):
- """
- The result of computing a dynamic selection index is determined as follows: when the index in the
- select operation is unbacked, the actual index calculation is ambiguous for negative indices
- (index + size) versus non-negative indices (just index). To resolve this, we allocate an unbacked
- SymInt to represent the storage offset and decompose the select operation into a call to as_strided,
- computing the storage offset at runtime with this node.
- """
- def get_reads(self) -> OrderedSet[Dep]:
- return OrderedSet()
- def should_allocate(self) -> bool:
- return False
- def __init__(
- self,
- unbacked_offset_symbol: sympy.Symbol,
- index: sympy.Symbol,
- base_offset: Union[sympy.Symbol, int],
- base_dim_stride: Union[sympy.Symbol, int],
- size: Union[sympy.Symbol, int],
- ) -> None:
- super().__init__(None, NoneLayout(device=torch.device("cpu")), [])
- # This node codegen the following:
- # unbacked_offset_symbol = base_offset + base_dim_stride * (index if index >=0 else index + size)
- self.unbacked_offset_symbol = unbacked_offset_symbol
- self.index = index
- self.base_offset = base_offset
- self.base_dim_stride = base_dim_stride
- self.size = size
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet([self.unbacked_offset_symbol])
- @cache_on_self_and_args("DynamicSelectStorageOffset")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return get_free_symbols(self.index, unbacked_only)
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.codegen_dynamic_select_index(self)
- class DynamicScalar(ExternKernel):
- """
- The result of a call to aten._local_scalar_dense.
- """
- def get_reads(self) -> OrderedSet[Dep]:
- return OrderedSet()
- def should_allocate(self) -> bool:
- return False
- def __init__(
- self, sym: sympy.Symbol, keypath: pytree.KeyPath, data: IRNode
- ) -> None:
- data.realize()
- super().__init__(
- None, NoneLayout(device=torch.device("cpu")), self.unwrap_storage([data])
- )
- self.sym = sym
- self.keypath = keypath
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- return OrderedSet([self.sym])
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.codegen_dynamic_scalar(self)
- class AssertScalar(ExternKernel):
- """
- The result of a call to aten._assert_scalar
- """
- def get_reads(self) -> OrderedSet[Dep]:
- return OrderedSet()
- def should_allocate(self) -> bool:
- return False
- def __init__(self, scalar: SympyBoolean, msg: str) -> None:
- super().__init__(
- # Buffer(name, layotu)
- None,
- NoneLayout(device=torch.device("cpu")),
- # InputsKernel(inputs)
- [],
- )
- self.scalar = scalar
- self.msg = msg
- def has_side_effects(self) -> bool:
- return True
- @cache_on_self_and_args("AssertScalar")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return get_free_symbols(self.scalar, unbacked_only)
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- if not config.scalar_asserts:
- return
- # NB: It is EXTREMELY important not to simplify the scalar under assertion here,
- # because simplify is done with respect to runtime asserts. So if you have
- # "u0 == 0" in the runtime asserts, if you subsequently try to
- # simplify(u0 == 0), you will get True (because we've already runtime assert'ed
- # that it's true). But we're code generating the actual runtime assert here!!
- symbol = next(iter(self.get_free_symbol_uses(unbacked_only=False)))
- if V.graph.fx_wrapper:
- # TODO fix
- pass
- elif V.graph.cpp_wrapper:
- symbol_str = f"std::to_string({symbol})"
- sizevar = V.graph.wrapper_code.codegen_cpp_sizevar(
- self.scalar, simplify=False
- )
- # TODO: when we start compiling in C++20, annotate with [[unlikely]].
- wrapper.writeline(
- f'if (!({sizevar})) {{ throw std::runtime_error("Expected {self.msg} but received " + {symbol_str}); }}'
- )
- else:
- sizevar = V.graph.wrapper_code.codegen_python_sizevar(
- self.scalar, simplify=False
- )
- wrapper.writeline(f"if not ({sizevar}):")
- wrapper.writeline(f" raise RuntimeError({repr(self.msg)})")
- # No one should ever use this buffer, but for uniformity
- # define the variable and assign it None
- wrapper.writeline(f"{self.get_name()} = None")
- @ir_dataclass(frozen=False)
- class ExternKernelNode:
- name: str
- node: export_schema.Node
- class FallbackKernel(ExternKernelAlloc):
- """
- A class that represents a fallback kernel for handling operators that are not
- directly support by inductor. It currently supports functional ops, view ops,
- inplace aten ops, and mutating ops that are auto-functionalizable.
- """
- def __init__(
- self,
- layout: OutputSpec,
- kernel: _OpOverloads,
- tensor_args: Sequence[IRNode],
- nontensor_args: Sequence[Any],
- unflatten_args: Callable[..., Any],
- kwargs: Optional[dict[str, Any]] = None,
- *,
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
- ) -> None:
- super().__init__(
- layout,
- tuple(tensor_args),
- tuple(nontensor_args),
- op_overload=kernel,
- )
- self.use_runtime_dispatch = False
- self.unbacked_bindings = unbacked_bindings or {}
- assert isinstance(
- kernel, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
- ), f"Fails to create FallbackKernel for {kernel}: {type(kernel)} not supported"
- self.op_overload = kernel
- self.unflatten_args = unflatten_args
- self.kwargs = {} if kwargs is None else kwargs
- assert self.python_kernel_name is not None
- V.graph.warn_fallback(self.python_kernel_name)
- # args that are aliased
- self.alias_names: list[str] = []
- # args that are mutated AND returned from the op
- self.mutation_names: list[str] = []
- if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
- # We assume here that HOPs with FallbackKernel are functional.
- # This may not always be true! HOPs must individually opt-in to
- # FallbackKernel, so please check this if you opt-in.
- return
- if "_c10d_functional" in self.op_overload.name():
- # _c10d_functional kernels are lowered into _CollectiveKernel which
- # derives from FallbackKernel for the cpp codegen. The kernels
- # don't pass the can_auto_functionalize check, but their mutation
- # is handled properly by _CollectiveKernel.
- return
- schema = self.op_overload._schema
- # NOTE: [FallbackKernel supported operators]
- # We only support three types of operators:
- # - functional ops
- # - view ops
- # - inplace aten ops
- # - mutating ops that are auto-functionalizable. That is,
- # the operator may mutate any number of inputs, but its outputs
- # may not alias any of the inputs.
- #
- # The unsupported cases usually do not show up here (because
- # AOTAutograd functionalized them away); the only way for an in-place
- # op to show up here is if a lowering or pass introduced it.
- if torch._library.utils.mutates_and_returns_first_arg(self.op_overload):
- self.mutation_names.append(tensor_args[0].get_name())
- return
- if schema.is_mutable and not can_auto_functionalize(kernel):
- raise NotImplementedError(
- f"NYI: Can't generate FallbackKernel for {kernel}"
- )
- args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
- def handle_aliasing_and_mutation(info: torch._C.Argument, arg: Any) -> None:
- # Assertions to make sure we didn't mismatch args
- if isinstance(info.type, torch.ListType):
- assert isinstance(arg, (list, tuple)), type(arg)
- if library_utils.is_tensor_like_type(info.type):
- # PyTorch also accepts None and scalar types for args marked as "Tensor".
- # We're not going to check all of them here.
- assert not isinstance(arg, (tuple, list))
- if arg is None:
- return
- if info.alias_info is None:
- return
- def add_alias(t: IRNode) -> None:
- self.alias_names.append(t.get_name())
- assert info.alias_info is not None
- if info.alias_info.is_write:
- self.mutation_outputs.append(
- MutationOutput(NoneLayout(device=t.get_device()), t, self)
- )
- if library_utils.is_tensorlist_like_type(info.type):
- if arg is not None:
- for optional_tensor_arg in arg:
- add_alias(optional_tensor_arg)
- else:
- assert library_utils.is_tensor_like_type(info.type)
- add_alias(arg)
- for info, arg in torch._library.utils.zip_schema(schema, args, kwargs):
- handle_aliasing_and_mutation(info, arg)
- def get_read_writes(self) -> dependencies.ReadWrites:
- read_writes = super().get_read_writes()
- if self.op_overload is torch._prims.rng_prims.graphsafe_run_with_rng_state:
- for arg in self.constant_args:
- if isinstance(arg, GeneratorState):
- read_writes = read_writes.with_read(
- dependencies.StarDep(arg.get_name())
- )
- return read_writes
- def codegen_unbacked_symbol_defs(self, wrapper: PythonWrapperCodegen) -> None:
- return wrapper.codegen_unbacked_symbol_defs_for_outputs(
- self.get_name(), self.outputs, getattr(self, "unbacked_bindings", None)
- )
- def get_unbacked_symbol_defs(self) -> Container[sympy.Symbol]: # type: ignore[override]
- if unbacked_bindings := getattr(self, "unbacked_bindings", None):
- resolved = resolve_unbacked_bindings(
- V.graph.sizevars.shape_env, unbacked_bindings
- )
- assert resolved is not None
- return resolved.keys()
- else:
- return OrderedSet()
- def codegen_args(self) -> list[str]:
- @dataclasses.dataclass
- class Shim:
- ref: Any
- def __repr__(self) -> str:
- return self.ref
- assert is_node_sequence(self.inputs)
- tensor_args = [Shim(x.codegen_reference()) for x in self.inputs]
- args, kwargs = self.unflatten_args(tensor_args, self.constant_args)
- if V.graph.cpp_wrapper and isinstance(self.op_overload, torch._ops.OpOverload):
- args = self.fill_non_provided_args(args, kwargs)
- args = [
- V.graph.wrapper_code.val_to_arg_str(x, param.real_type)
- for param, x in zip(self.op_overload._schema.arguments, args)
- ]
- else:
- args = [V.graph.wrapper_code.val_to_arg_str(x) for x in args]
- # let self.codegen_kwargs handle kwargs
- self.kwargs.update(kwargs)
- return args
- @staticmethod
- def find_device(
- tensor_args: Optional[Sequence[torch.Tensor]], example_output: Sequence[Any]
- ) -> Any:
- non_torch_bind_tensor_args = (
- [t for t in tensor_args if not isinstance(t, TorchBindObject)]
- if tensor_args
- else None
- )
- if non_torch_bind_tensor_args:
- assert tensor_args
- devices = [arg.get_device() for arg in tensor_args if arg.get_device()]
- return devices[0]
- if isinstance(example_output, torch.Tensor):
- return example_output.device
- if isinstance(example_output, (list, tuple)):
- device_set = OrderedSet(
- FallbackKernel.find_device(None, x) for x in example_output
- )
- # Remove None
- devices = [device for device in device_set if device]
- if len(devices) == 1:
- return devices[0]
- for device in devices:
- assert isinstance(device, torch.device)
- if is_gpu(device.type):
- return device
- return devices[0]
- return None
- def has_side_effects(self) -> bool:
- if isinstance(self.op_overload, torch._ops.HigherOrderOperator):
- return False
- return get_schema_info(self.op_overload).is_mutable()
- def get_inputs_that_alias_output(self) -> Sequence[str]:
- assert isinstance(
- self.op_overload, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)
- ), (
- f"Fails to create FallbackKernel for {self.op_overload}: "
- f"{type(self.op_overload)} not supported"
- )
- # See [Note: FallbackKernel supported operators]: for a mutating
- # op that is auto-functionalizable, its outputs does NOT
- # alias any of the inputs.
- if (
- not isinstance(self.op_overload, torch._ops.HigherOrderOperator)
- and "_c10d_functional" not in self.op_overload.name()
- and self.op_overload._schema.is_mutable
- and can_auto_functionalize(self.op_overload)
- ):
- return []
- else:
- return self.alias_names
- def get_mutation_names(self) -> Sequence[str]:
- assert len(self.mutation_names) <= 1
- return self.mutation_names
- def export_extern_kernel_node(self): # type: ignore[no-untyped-def]
- """
- ProxyExecutor Design Note
- We export the ExternFallbackNodes (for custom ops) into a serialized file
- and run it with a host side proxy executor to address the ABI problem
- This is currently only implemented for fbcode. Eventually, we will also make this work for OSS.
- Detailed design doc can be found at
- https://docs.google.com/document/d/1wC4DOZFaYym2t1Esz0X5yxlLI3RDnSiyRbUus3bkJ64/edit?usp=sharing
- """
- log.debug(
- "Extern kernel node added for node %s with target %s.",
- self.get_name(),
- self.op_overload,
- )
- assert isinstance(self, FallbackKernel), type(self)
- args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
- args = self.fill_non_provided_args(args, kwargs)
- ordered_kwargs = [
- self.get_kwargs_value(key, **kwargs)
- for key in self.ordered_kwargs_for_cpp_kernel
- ]
- target = self.op_overload
- if not V.graph.aot_mode:
- # No need to serialize in the cpp wrapper JIT mode
- return [*args, *ordered_kwargs]
- serializer = GraphModuleSerializer(None, []) # type: ignore[arg-type]
- named_arguments = serializer.serialize_inputs(target, args, kwargs)
- # serialize_outputs
- def handle_single_output(
- return_type: Union[torch.TensorType, torch.ListType, torch.JitType],
- output: Union[IRNode, Sequence[IRNode]],
- ) -> export_schema.Argument:
- if isinstance(return_type, (torch.TensorType, torch.NoneType)):
- # For single Tensor or None
- out = output
- if isinstance(output, (list, tuple)):
- assert len(output) == 1
- out = output[0]
- if isinstance(return_type, torch.TensorType):
- assert isinstance(out, IRNode)
- return export_schema.Argument.create(
- as_tensor=export_schema.TensorArgument(name=out.get_name())
- )
- else: # NoneType
- assert out is None
- return export_schema.Argument.create(as_none=True)
- elif isinstance(return_type, torch.ListType) and isinstance(
- return_type.getElementType(), torch.TensorType
- ):
- assert isinstance(output, Sequence), type(output)
- # For single TensorList
- return export_schema.Argument.create(
- as_tensors=[
- export_schema.TensorArgument(name=out.get_name())
- for out in output
- ]
- )
- elif isinstance(return_type, torch.OptionalType) and isinstance(
- return_type.getElementType(), torch.TensorType
- ):
- # For OptionalTensor
- if output is None:
- return export_schema.Argument.create(
- as_optional_tensor=export_schema.OptionalTensorArgument.create(
- as_none=True
- )
- )
- else:
- assert isinstance(output, IRNode)
- return export_schema.Argument.create(
- as_optional_tensor=export_schema.OptionalTensorArgument.create(
- as_tensor=export_schema.TensorArgument(
- name=output.get_name()
- )
- )
- )
- elif isinstance(return_type, torch.IntType):
- return export_schema.Argument.create(as_int=output)
- else:
- raise RuntimeError(f"Unsupported return type {type(return_type)}")
- if isinstance(target, torch._higher_order_ops.torchbind.CallTorchBind):
- returns = target.schema(args[0], args[1]).returns
- else:
- returns = target._schema.returns # type: ignore[union-attr]
- if len(returns) == 1:
- # NOTE: [special handling of all_reduce_coalesced_'s return value]
- # all_reduce_coalesced_ return a list of tensors via self.mutation_outputs
- outputs = self.outputs if self.outputs else self.mutation_outputs
- return_type = returns[0].real_type
- output_arguments = [handle_single_output(return_type, outputs)]
- else:
- # For tuple returns, e.g "-> (Tensor, Tensor)" or "-> (Tesnor, Tensor[])"
- # Not generating output args for self.mutation_outputs
- output_arguments = [
- handle_single_output(
- return_schema.real_type, # type: ignore[attr-defined]
- output,
- )
- for return_schema, output in zip(returns, self.outputs)
- ]
- assert self.op_overload is not None
- node = ExternKernelNode(
- name=self.get_name(),
- node=export_schema.Node(
- target=self.op_overload.name(),
- inputs=named_arguments,
- outputs=output_arguments,
- metadata={},
- ),
- )
- V.extern_kernel_nodes.append(node)
- return [*args, *ordered_kwargs]
- @override
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- """Overrides the parent member.
- See https://github.com/pytorch/pytorch/issues/151692"""
- kernel = self.op_overload
- assert kernel is not None
- if kernel.namespace == "aten":
- # Aten Fallback Ops
- assert isinstance(kernel, torch._ops.OpOverload), type(kernel)
- if V.graph.cpp_wrapper:
- from torchgen.aoti.fallback_ops import inductor_fallback_ops
- if str(kernel) not in inductor_fallback_ops:
- # C shim v2 is torchgen-ed, which should cover all aten ops.
- # If you do hit a missed op, please update fallback_ops.py.
- log.warning(
- "%s is missing a c-shim implementation, using proxy executor as fallback",
- kernel,
- )
- self.use_runtime_dispatch = True
- elif kernel.namespace == "_quantized":
- # Internal Quantized Fallback Ops
- assert isinstance(kernel, torch._ops.OpOverload), type(kernel)
- elif V.graph.cpp_wrapper:
- # For non-aten OpOverload, i.e. custom ops
- # If the op is in custom_ops_to_c_shims, generate direct function call
- self.use_runtime_dispatch = (
- kernel not in config.aot_inductor.custom_ops_to_c_shims
- )
- # Handle the special case where a complex number is input to a C-shim kernel for
- # a scalar input. The torchgen'ed shim API will use type "double", which is
- # incompatible with complex numbers, forcing a fallback to runtime dispatch.
- if (
- V.graph.cpp_wrapper
- and isinstance(kernel, torch._ops.OpOverload)
- and not self.use_runtime_dispatch
- ):
- def is_number(t: torch.JitType) -> bool:
- if isinstance(t, torch.OptionalType):
- return is_number(t.getElementType())
- return isinstance(t, torch.NumberType)
- # Using unflatten_args is a bit of a hack, but all the complex arguments we
- # care about are in self.constant_args, and calling unflatten_args puts them
- # in the correct order without triggering codegen.
- args, kwargs = self.unflatten_args(self.inputs, self.constant_args)
- # Append kwarg values to args. ordered_kwargs_for_cpp_kernel is guaranteed
- # to be set, since this is an OpOverload kernel.
- args_iter = itertools.chain(
- args,
- (
- self.get_kwargs_value(k, **kwargs)
- for k in self.ordered_kwargs_for_cpp_kernel
- ),
- )
- self.use_runtime_dispatch = any(
- isinstance(v, complex) and is_number(a.real_type)
- for v, a in zip(args_iter, kernel._schema.arguments)
- )
- self.codegen_comment(wrapper)
- if self.use_runtime_dispatch:
- exported_args = self.export_extern_kernel_node()
- assert self.python_kernel_name is not None
- assert self.op_overload is not None
- wrapper.generate_fallback_kernel_with_runtime_lookup(
- self.get_name(),
- self.python_kernel_name,
- lambda: [*self.codegen_args(), *self.codegen_kwargs()],
- self.op_overload,
- exported_args,
- # NOTE: [special handling of all_reduce_coalesced_'s return value]
- self.outputs if self.outputs else self.mutation_outputs,
- )
- else:
- wrapper.generate_fallback_kernel(self)
- if isinstance(self.layout, Layout):
- self.codegen_size_asserts(wrapper)
- self.codegen_alignment_asserts(wrapper)
- self.codegen_memory_tracking(wrapper)
- self.codegen_unbacked_symbol_defs(wrapper)
- @staticmethod
- def tensor_to_layout(output: torch.Tensor) -> FixedLayout:
- is_pinned = False
- try:
- is_pinned = output.is_pinned()
- except RuntimeError:
- # dispatch not implemented
- pass
- return FixedLayout(
- output.device,
- output.dtype,
- convert_shape_to_inductor(output.size()),
- convert_shape_to_inductor(output.stride()),
- is_pinned=is_pinned,
- )
- @classmethod
- def create(cls, kernel: _OpOverloads, *args: Any, **kwargs: Any) -> FallbackKernel:
- """Create an instance of FallbackKernel from an _OpOverloads"""
- fake_incorrect_kernels = (aten._fused_moving_avg_obs_fq_helper_functional,)
- if kernel not in fake_incorrect_kernels:
- context = cast(AbstractContextManager[None], V.graph.fake_mode)
- else:
- context = nullcontext()
- with context:
- (
- example_output,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- unbacked_bindings,
- ) = cls.process_kernel(kernel, *args, **kwargs)
- # We need this extra check for input alignment since the example
- # inputs we created are always aligned.
- has_unaligned_input = any(is_unaligned(arg) for arg in tensor_args)
- device = cls.find_device(tensor_args, example_output)
- if not device and isinstance(
- kernel, torch._higher_order_ops.torchbind.CallTorchBind
- ):
- # use CPU device for torchbind methods that don't take in or output any tensor, e.g. size()
- device = torch.device("cpu")
- if example_output is None:
- packed = cls(
- NoneLayout(device=device),
- kernel,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- unbacked_bindings=unbacked_bindings,
- )
- else:
- assert device, "Not sure where to find device info"
- packed = cls(
- MultiOutputLayout(device=device),
- kernel,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- unbacked_bindings=unbacked_bindings,
- )
- def generate_output(output: Any, indices: list[tuple[Any, int]]) -> Any:
- if isinstance(output, (list, tuple)):
- return type(output)(
- generate_output(output[i], indices + [(type(output), i)])
- for i in range(len(output))
- )
- elif isinstance(output, dict):
- return {
- key: generate_output(val, indices + [(type(output), key)])
- for key, val in output.items()
- }
- elif isinstance(output, torch.Tensor):
- buf = MultiOutput(
- cls.tensor_to_layout(output),
- packed,
- indices,
- )
- if (
- config.assume_unaligned_fallback_output
- or has_unaligned_input
- or not tensor_is_aligned(output)
- ):
- V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type]
- return buf
- elif isinstance(output, int):
- return output
- elif isinstance(output, torch.SymInt):
- return output.node.expr
- else:
- assert output is None, (
- f"FallbackKernel output type {type(output)} is not supported"
- )
- return None
- outputs = generate_output(example_output, [])
- if isinstance(outputs, (list, tuple)):
- packed.outputs = outputs
- elif isinstance(outputs, dict):
- packed.outputs = tuple(outputs)
- else:
- packed.outputs = [outputs]
- return outputs
- def apply_constraint(self) -> None:
- return super().apply_constraint()
- @ir_dataclass(frozen=False)
- class ComplexView(FallbackKernel):
- """View a complex number as two dtyped numbers or vice versa"""
- def should_allocate(self) -> bool:
- return False
- def get_inputs_that_alias_output(self) -> Sequence[str]:
- # Signal to codegen that our output buffer isn't safe to reuse
- return [self.input_name(0)]
- def __init__(
- self,
- layout: OutputSpec,
- kernel: _OpOverloads,
- tensor_args: Sequence[IRNode],
- nontensor_args: Sequence[Any],
- unflatten_args: Callable[..., Any],
- *,
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
- ) -> None:
- super().__init__(
- layout,
- kernel,
- tensor_args,
- nontensor_args,
- unflatten_args,
- unbacked_bindings=unbacked_bindings,
- )
- class MemoryCheckKernel(FallbackKernel):
- """
- Custom kernel for memory checking that generates direct function calls
- TODO - the custom op was erroring with str inputs. should be able to custom op directly.
- """
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- """Override codegen to write direct function call"""
- # Extract our arguments from nontensor_args
- wrapper.write_memory_track_allocation_once()
- alive_list, dead_list, is_final_step = self.constant_args
- alive_repr = repr(alive_list)
- dead_repr = repr(dead_list)
- if is_final_step:
- wrapper.writeline(
- "# note: dont currently distinguish between buffers returned and dealloc'd in last step"
- )
- call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr}, is_final_step={is_final_step})"
- else:
- call = f"check_memory_step(allocated={alive_repr}, freed={dead_repr})"
- wrapper.writeline(call)
- @ir_dataclass
- class MultiOutputLayout(OutputSpec):
- device: torch.device
- def get_device(self) -> Optional[torch.device]:
- return self.device
- class MultiOutput(ExternKernel):
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.codegen_multi_output(self)
- if not self.skip_size_stride_alignment_checks:
- self.codegen_size_asserts(wrapper)
- self.codegen_alignment_asserts(wrapper)
- def __init__(
- self,
- layout: OutputSpec,
- input: IRNode,
- indices: list[tuple[Any, ...]],
- skip_size_stride_alignment_checks: bool = False,
- ) -> None:
- super().__init__(None, layout, [input], ())
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- self.indices = indices
- self.skip_size_stride_alignment_checks = skip_size_stride_alignment_checks
- @cache_on_self_and_args("MultiOutput")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- input_node = self.inputs[0]
- assert isinstance(input_node, IRNode), input_node
- return input_node.get_free_symbol_uses(unbacked_only)
- def should_allocate(self) -> bool:
- return len(self.inputs) == 1 and (
- isinstance(self.inputs[0], CppTemplateBuffer) # Grouped GEMM
- )
- def get_inputs_that_alias_output(self) -> Sequence[str]:
- return [
- inp.get_name()
- for inp in self.inputs
- if isinstance(inp, FallbackKernel)
- and len(inp.get_inputs_that_alias_output()) > 0
- ]
- # We just use a normal dataclass for MutableBox/TensorBox/StorageBox since
- # they're mainly lowering-time constructs that we expect to mutate and such.
- @dataclasses.dataclass
- class MutableBox(IRNode):
- """
- TensorBox / StorageBox allow in-place mutation of Tensors
- """
- data: IRNode
- def has_exceeded_max_reads(self) -> bool:
- return self.data.has_exceeded_max_reads()
- def get_device(self) -> Optional[torch.device]:
- return self.data.get_device()
- def make_loader(self) -> Callable[[Sequence[Expr]], OpsValue]:
- return self.data.make_loader()
- def make_indexer(self) -> Callable[[Sequence[Expr]], Expr]:
- return self.data.make_indexer()
- def get_stride(self) -> Sequence[_IntLike]:
- return self.data.get_stride()
- def get_name(self) -> str:
- return self.data.get_name()
- def has_large_inner_fn(self, threshold: Optional[int] = None) -> bool:
- return self.data.has_large_inner_fn(threshold)
- def mark_reuse(self, users: int) -> None:
- return self.data.mark_reuse(users)
- def realize_hint(self) -> None:
- return self.data.realize_hint()
- def unwrap_view(self) -> IRNode:
- return self.data.unwrap_view()
- def is_input_buffer(self) -> bool:
- return self.data.is_input_buffer()
- def freeze_layout(self) -> None:
- return self.data.freeze_layout()
- def freeze_layout_with_stride_order(
- self, order: Sequence[int], allow_padding: bool = False
- ) -> None:
- return self.data.freeze_layout_with_stride_order(order, allow_padding)
- def freeze_layout_with_fill_order(self, order: Sequence[int]) -> None:
- return self.data.freeze_layout_with_fill_order(order)
- def freeze_layout_with_same_order(self, stride: Sequence[_IntLike]) -> None:
- return self.data.freeze_layout_with_same_order(stride)
- def freeze_layout_with_exact_strides(
- self, exact_strides: Sequence[_IntLike], allow_padding: bool = False
- ) -> None:
- return self.data.freeze_layout_with_exact_strides(exact_strides, allow_padding)
- def get_read_writes(self) -> dependencies.ReadWrites:
- return self.data.get_read_writes()
- def get_reads(self) -> OrderedSet[Dep]:
- return self.data.get_reads()
- def num_reads(self) -> int:
- return self.data.num_reads()
- def get_storage_numel(self) -> _IntLike:
- return self.data.get_storage_numel()
- def get_reduction_type(self) -> Optional[str]:
- return self.data.get_reduction_type()
- def get_reduction_size(self) -> Sequence[Expr]:
- return self.data.get_reduction_size()
- def is_extern(self) -> bool:
- return self.data.is_extern()
- def is_no_op(self) -> bool:
- return self.data.is_no_op()
- def constant_to_device(self, device: torch.device) -> IRNode:
- return self.data.constant_to_device(device)
- def get_mutation_names(self) -> Sequence[str]:
- return self.data.get_mutation_names()
- def get_operation_name(self) -> str:
- return self.data.get_operation_name()
- def get_inputs_that_alias_output(self) -> Sequence[str]:
- return self.data.get_inputs_that_alias_output()
- def realize(self) -> Optional[str]:
- return self.data.realize()
- @cache_on_self_and_args("MutableBox")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return self.data.get_free_symbol_uses(unbacked_only)
- def get_read_names(self) -> OrderedSet[str]:
- return self.data.get_read_names()
- def get_defining_op(self) -> Optional[Operation]:
- return self.data.get_defining_op()
- def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
- return self.data.codegen_reference(writer)
- @property
- def layout(self) -> OutputSpec:
- # we intentionally call get_output_spec (rather than get_layout) since Buffer.layout is an OutputSpec
- return self.data.get_output_spec()
- def get_layout(self) -> Layout:
- return self.data.get_layout()
- def get_output_spec(self) -> OutputSpec:
- return self.data.get_output_spec()
- def get_size(self) -> Sequence[Expr]:
- return self.data.get_size()
- @property
- def dtype(self) -> torch.dtype:
- return self.data.dtype
- def __str__(self) -> str:
- if isinstance(self.data, MutableBox):
- line0 = f"{type(self).__name__}({type(self.data).__name__}("
- endl = "))"
- inner = self.data.data
- else:
- line0 = f"{type(self).__name__}("
- inner = self.data
- endl = ")"
- lines = [
- line0,
- indent(str(inner)),
- endl,
- ]
- return "\n".join(lines)
- __repr__ = __str__
- class TensorBox(MutableBox):
- @staticmethod
- def create(data: IRNode) -> Union[TensorBox, ShapeAsConstantBuffer]:
- if isinstance(data, ShapeAsConstantBuffer):
- return data
- return TensorBox(StorageBox(data))
- class StorageBox(MutableBox):
- """
- StorageBox allow in-place mutation of Tensors
- """
- def is_input_buffer(self) -> bool:
- if isinstance(self.data, (InputBuffer, ReinterpretView)):
- return self.data.get_name() in V.graph.graph_inputs
- return False
- def is_module_buffer(self) -> bool:
- return (
- isinstance(self.data, (ConstantBuffer))
- and self.data.get_name() in V.graph.constants
- )
- def realize(self) -> Optional[str]:
- if IRNode.is_realized_node(self.data):
- return self.data.get_name()
- assert isinstance(self.data, (Pointwise, Reduction, Scan, Sort)), type(
- self.data
- )
- origin_node = self.data.get_origin_node()
- traceback = self.data.get_traceback()
- device = self.data.get_device()
- assert device is not None
- self.data = ComputedBuffer(
- name=None,
- layout=FlexibleLayout(
- device=device,
- dtype=self.data.get_dtype(),
- size=self.data.get_size(),
- is_pinned=False,
- ),
- data=self.data,
- )
- self.data.name = V.graph.register_buffer(self.data)
- V.graph.register_operation(self.data)
- self.data.origins = self.origins
- self.data.origin_node = origin_node
- self.data.traceback = traceback
- return self.data.name
- def realize_hint(self) -> None:
- """
- Called on buffers we expect to be forced to realize later.
- """
- if (
- isinstance(self.data, (Pointwise, Reduction))
- and self.data.inner_fn_opcount().nontrivial_read_count > 1
- ):
- self.realize()
- def has_accumulated_enough_reads_by_size(self, threshold: int) -> bool:
- return (
- sum(V.graph.get_dep_size_hint(dep) for dep in self.get_reads()) > threshold
- )
- def has_exceeded_max_reads(self) -> bool:
- return isinstance(self.data, Pointwise) and (
- self.num_reads() > config.realize_acc_reads_threshold
- or self.has_large_inner_fn()
- or (
- config.realize_acc_reads_size_threshold is not None
- and self.has_accumulated_enough_reads_by_size(
- config.realize_acc_reads_size_threshold
- )
- )
- )
- def should_realize_on_reuse(self, users: int) -> bool:
- """
- A heuristic to decide if we should realize a tensor
- that is used multiple times.
- """
- if users > 1 and isinstance(self.data, (Pointwise, Reduction)):
- if is_cpu(self.data):
- # Heuristic for realizing reused result of heavy ops on cpu
- opcount = self.data.inner_fn_opcount()
- heavy_ops = ["exp", "sigmoid"] # a list of heavy ops
- if any(x in opcount.used_ops for x in heavy_ops):
- return True
- return (
- self.num_reads() > config.realize_reads_threshold
- or self.has_large_inner_fn()
- )
- return False
- def mark_reuse(self, users: int) -> None:
- if self.should_realize_on_reuse(users):
- self.realize()
- def num_reads(self) -> int:
- return self.data.num_reads()
- @ir_dataclass(frozen=False)
- class Subgraph(IRNode):
- name: str
- graph_module: torch.fx.GraphModule
- graph: Optional[GraphLowering] = None
- def _has_aliased_buffers(buffers: Sequence[IRNode]) -> bool:
- buffers = [
- buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer
- for buffer in buffers
- ]
- # assuming the same buffer is represented by the same IRNode object
- return len(OrderedSet(id(buffer) for buffer in buffers)) < len(buffers)
- @ir_dataclass(frozen=False)
- class InvokeSubgraph(ExternKernel):
- """
- Ir node for the invoke_subgraph HOP.
- """
- subgraph: Optional[Subgraph] = None
- operands: Optional[Sequence[IRNode]] = None
- outputs: Optional[Sequence[IRNode]] = None
- def __init__(
- self, subgraph: Subgraph, operands: Sequence[IRNode], layout: MultiOutputLayout
- ) -> None:
- super().__init__(
- name=None,
- layout=layout,
- inputs=operands,
- )
- self.subgraph = subgraph
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- @classmethod
- def create(
- cls, subgraph: Subgraph, *operands: IRNode
- ) -> list[Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]]:
- """For each operand, get a realized input, force it to have the same
- strides as the subgraph inputs, then use an InvokeSubgraph"""
- from .lowering import constrain_to_fake_tensor
- # TODO(anijain2305) - Support sym expr as operands in future.
- current_node = V.graph.current_node
- fake_operands = None
- if eager_input_vals := current_node.meta.get("eager_input_vals"):
- # eager_input_vals is (args_values, kwargs_values). We need args for invoke_subgraph
- fake_operands = eager_input_vals[0][2:]
- else:
- # For the partitioned backward graph, we do not have
- # eager_input_vals. Here, we rely on the recorded example values.
- fx_operands = current_node.args[2:]
- fake_operands = [x.meta["val"] for x in fx_operands] # type: ignore[union-attr]
- # Realize the inputs. Also intermediates can have different strides than
- # the inputs of the subgraph. So, force the intermediates to have same
- # strides as that of subgraph inputs.
- operands: list[IRNode] = [cls.realize_input(x) for x in operands]
- new_operands: list[IRNode] = []
- for idx, operand in enumerate(operands):
- if isinstance(operand, (ShapeAsConstantBuffer, GeneratorState)):
- new_operands.append(operand)
- else:
- new_operands.append(
- constrain_to_fake_tensor(operand, fake_operands[idx])
- )
- operands = new_operands
- if subgraph.graph is None:
- # create and lower subgraphs
- subgraph.graph = V.graph.make_subgraph(
- gm=subgraph.graph_module,
- example_inputs=fake_operands,
- subgraph_name=subgraph.name,
- )
- with V.set_graph_handler(subgraph.graph):
- subgraph.graph.run(*fake_operands)
- outputs = subgraph.graph.graph_outputs
- # Find the device - operands could be integers from shapes, so we can't
- # use operands[0]
- device = None
- for operand in operands:
- if not isinstance(operand, ShapeAsConstantBuffer):
- device = operand.get_device()
- break
- assert device is not None
- invoke_subgraph = InvokeSubgraph(
- subgraph=subgraph,
- operands=operands,
- layout=MultiOutputLayout(device=device),
- )
- def create_output(
- output: IRNode, ind: int
- ) -> Union[ShapeAsConstantBuffer, NoneAsConstantBuffer, MultiOutput]:
- if isinstance(output, (ShapeAsConstantBuffer, NoneAsConstantBuffer)):
- return output
- else:
- device = output.get_device()
- assert device is not None
- return MultiOutput(
- FixedLayout(
- device=device,
- dtype=output.get_dtype(),
- size=output.get_size(),
- stride=output.get_stride(),
- offset=output.get_layout().offset,
- is_pinned=output.get_layout().is_pinned,
- ),
- invoke_subgraph, # type: ignore[has-type]
- [(list, ind)],
- skip_size_stride_alignment_checks=True,
- )
- outs = [create_output(output, i) for i, output in enumerate(outputs)]
- invoke_subgraph.outputs = outs # type: ignore[assignment]
- return outs
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.codegen_invoke_subgraph(self)
- @ir_dataclass(frozen=False)
- class Conditional(ExternKernel):
- predicate: Optional[IRNode] = None
- operands: Optional[Sequence[IRNode]] = None
- true_subgraph: Optional[Subgraph] = None
- false_subgraph: Optional[Subgraph] = None
- outputs: Optional[Sequence[MultiOutput]] = None
- def __init__(
- self,
- predicate: IRNode,
- operands: Sequence[IRNode],
- true_subgraph: Subgraph,
- false_subgraph: Subgraph,
- layout: MultiOutputLayout,
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
- ) -> None:
- self.predicate = predicate
- self.operands = operands
- self.true_subgraph = true_subgraph
- self.false_subgraph = false_subgraph
- sym_args, tensor_args = _split_by_sym_type([predicate, *operands])
- super().__init__(
- name=None,
- layout=layout,
- inputs=tensor_args,
- constant_args=sym_args,
- )
- if unbacked_bindings is not None:
- self.unbacked_bindings = unbacked_bindings
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- @staticmethod
- def _maybe_expr(s: Union[int, torch.SymInt]) -> Union[int, sympy.Expr]:
- if isinstance(s, int):
- return s
- return s.node.expr
- @classmethod
- def create(
- cls,
- predicate: TensorBox,
- true_fn: Subgraph,
- false_fn: Subgraph,
- operands: list[Union[TensorBox, ShapeAsConstantBuffer]],
- ) -> Sequence[IRNode]:
- """Create a Sequence of IRNodes from a conditional statement (see .lowering.cond)"""
- predicate = cls.realize_input(predicate)
- operands = [cls.realize_input(x) for x in operands]
- fx_operands: Argument = V.graph.current_node.args[-1]
- assert isinstance(fx_operands, Sequence), type(fx_operands)
- assert all(isinstance(n, Node) for n in fx_operands)
- fake_operands = [cast(Node, x).meta["val"] for x in fx_operands]
- for subgraph in (true_fn, false_fn):
- if subgraph.graph is None:
- # create and lower subgraphs
- subgraph.graph = V.graph.make_subgraph(
- gm=subgraph.graph_module,
- example_inputs=fake_operands,
- subgraph_name=subgraph.name,
- )
- with V.set_graph_handler(subgraph.graph):
- subgraph.graph.run(*fake_operands)
- assert true_fn.graph is not None
- assert false_fn.graph is not None
- true_outputs = true_fn.graph.graph_outputs
- false_outputs = false_fn.graph.graph_outputs
- for name, outputs in (("true_fn", true_outputs), ("false_fn", false_outputs)):
- if _has_aliased_buffers(true_outputs):
- raise AssertionError(
- "Output aliasing is currently not supported in compiled torch.cond. "
- f"The outputs of the {name} subgraph of torch.cond are aliased: {outputs}"
- )
- # make sure true and false outputs are structurally equivalent
- assert len(true_outputs) == len(false_outputs), (true_outputs, false_outputs)
- for i, (t_o, f_o) in enumerate(zip(true_outputs, false_outputs)):
- assert t_o.get_device() == f_o.get_device(), (i, t_o, f_o)
- assert t_o.get_dtype() == f_o.get_dtype(), (i, t_o, f_o)
- assert t_o.get_layout().offset == f_o.get_layout().offset, (i, t_o, f_o)
- device = next(
- o.get_device()
- for o in [predicate] + operands
- if not isinstance(o, ShapeAsConstantBuffer)
- )
- unbacked_bindings = resolve_unbacked_bindings(
- V.graph.sizevars.shape_env,
- V.graph.current_node.meta.get("unbacked_bindings", None),
- )
- assert device is not None, "cannot determine device"
- conditional = Conditional(
- predicate=predicate,
- operands=operands,
- true_subgraph=true_fn,
- false_subgraph=false_fn,
- layout=MultiOutputLayout(device=device),
- unbacked_bindings=unbacked_bindings,
- )
- outputs = [
- MultiOutput(
- FixedLayout(
- device=device,
- dtype=output.get_dtype(),
- size=[Conditional._maybe_expr(sz) for sz in merged_output.size()],
- stride=[
- Conditional._maybe_expr(sz) for sz in merged_output.stride()
- ],
- offset=output.get_layout().offset,
- is_pinned=output.get_layout().is_pinned,
- ),
- conditional,
- [(list, i)],
- )
- # as the true and false outputs are equivalent,
- # we can use either of them here as a "template"
- for i, (output, merged_output) in enumerate(
- zip(true_outputs, V.graph.current_node.meta["val"])
- )
- ]
- conditional.outputs = outputs # type: ignore[assignment]
- return outputs
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.codegen_conditional(self)
- wrapper.codegen_unbacked_symbol_defs_for_outputs(
- self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
- )
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- if unbacked_bindings := getattr(self, "unbacked_bindings", None):
- resolved = resolve_unbacked_bindings(
- V.graph.sizevars.shape_env, unbacked_bindings
- )
- assert resolved is not None
- return OrderedSet(resolved.keys())
- else:
- return OrderedSet()
- def _split_by_sym_type(
- args: list[Any],
- ) -> tuple[list[ShapeAsConstantBuffer], list[Any]]:
- non_sym_args = []
- sym_args = []
- for arg in args:
- if isinstance(arg, ShapeAsConstantBuffer):
- sym_args.append(arg.expr)
- else:
- non_sym_args.append(arg)
- return sym_args, non_sym_args
- @ir_dataclass(frozen=False)
- class WhileLoop(ExternKernel):
- """The IR node for while_loop and while_loop_stack_output. It supports input mutation."""
- carried_inputs: Optional[Sequence[IRNode]] = None
- additional_inputs: Optional[Sequence[IRNode]] = None
- cond_subgraph: Optional[Subgraph] = None
- body_subgraph: Optional[Subgraph] = None
- outputs: Optional[Sequence[MultiOutput]] = None
- def __init__(
- self,
- carried_inputs: Sequence[IRNode],
- additional_inputs: Sequence[IRNode],
- cond_subgraph: Subgraph,
- body_subgraph: Subgraph,
- layout: MultiOutputLayout,
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]],
- stack_output: bool,
- ) -> None:
- self.carried_inputs = carried_inputs
- self.additional_inputs = additional_inputs
- self.cond_subgraph = cond_subgraph
- self.body_subgraph = body_subgraph
- sym_args, tensor_args = _split_by_sym_type(
- [*carried_inputs, *additional_inputs]
- )
- super().__init__(
- name=None,
- layout=layout,
- inputs=tensor_args,
- constant_args=sym_args,
- )
- if unbacked_bindings is not None:
- self.unbacked_bindings = unbacked_bindings
- self.stack_output = stack_output
- self.name = V.graph.register_buffer(self)
- V.graph.register_operation(self)
- # Accidental aliasing can be created due to cse, where the empty buffers we
- # allocated for backward to use gets csed into the same buffer in function fx_graph_cse.
- # See test_scan_multiple_layers_gradient for a concrete example.
- @staticmethod
- def _clone_aliased_inputs(carried_inputs: Sequence[IRNode]) -> Sequence[IRNode]:
- if not _has_aliased_buffers(carried_inputs):
- return carried_inputs
- # Import clone from lowering module
- from .lowering import clone
- # Unwrap views to get the underlying buffers for comparison
- unwrapped_buffers = [
- buffer.unwrap_view() if isinstance(buffer, ReinterpretView) else buffer
- for buffer in carried_inputs
- ]
- # Track which buffers we've seen and their indices
- seen_buffers: OrderedSet[int] = OrderedSet()
- result = []
- for i, (original_input, unwrapped_buffer) in enumerate(
- zip(carried_inputs, unwrapped_buffers)
- ):
- if id(unwrapped_buffer) in seen_buffers:
- result.append(clone(original_input))
- else:
- seen_buffers.add(id(unwrapped_buffer))
- result.append(original_input)
- return result
- @classmethod
- def create(
- cls,
- cond_fn: Subgraph,
- body_fn: Subgraph,
- carried_inputs: Sequence[IRNode],
- additional_inputs: Sequence[IRNode],
- stack_output: bool,
- ) -> Union[IRNode, Sequence[IRNode]]:
- """create the while_loop IR node. stack_output controls whether it stack
- each iterations' output, which is necessary for training.
- """
- from torch._higher_order_ops.utils import check_input_alias_and_mutation
- def _require_exact_strides(
- tensor_boxes: Sequence[IRNode],
- fake_tensors: list[Union[int, torch.SymInt, torch.Tensor]],
- ) -> list[IRNode]:
- assert len(tensor_boxes) == len(fake_tensors)
- ret = []
- for tb, fk in zip(tensor_boxes, fake_tensors):
- if isinstance(fk, torch.Tensor):
- ret.append(
- ExternKernel.require_exact_strides(
- tb, fk.stride(), allow_padding=False
- )
- )
- else:
- ret.append(tb)
- return ret
- fx_carried_inputs = V.graph.current_node.args[-2]
- fx_additional_inputs = V.graph.current_node.args[-1]
- fx_all_inputs = fx_carried_inputs + fx_additional_inputs # type: ignore[operator]
- fake_all_inputs = [x.meta["val"] for x in fx_all_inputs] # type: ignore[union-attr]
- fake_carried_inputs = [x.meta["val"] for x in fx_carried_inputs] # type: ignore[union-attr]
- fake_additional_inputs = [x.meta["val"] for x in fx_additional_inputs] # type: ignore[union-attr]
- carried_inputs_ = [cls.realize_input(x) for x in carried_inputs]
- carried_inputs_ = WhileLoop._clone_aliased_inputs(carried_inputs_)
- carried_inputs_ = _require_exact_strides(carried_inputs_, fake_carried_inputs)
- additional_inputs_ = [cls.realize_input(x) for x in additional_inputs]
- additional_inputs_ = _require_exact_strides(
- additional_inputs_, fake_additional_inputs
- )
- all_inputs = carried_inputs_ + additional_inputs_
- for subgraph in (cond_fn, body_fn):
- if subgraph.graph is None:
- # create and lower subgraphs
- assert isinstance(fx_all_inputs, Sequence), type(fx_all_inputs)
- subgraph.graph = V.graph.make_subgraph(
- gm=subgraph.graph_module,
- example_inputs=fx_all_inputs, # type: ignore[arg-type]
- subgraph_name=subgraph.name,
- )
- with V.set_graph_handler(subgraph.graph):
- subgraph.graph.run(*fake_all_inputs)
- # For body_fn, we require its output to have the exact same stride
- # as inputs because the previous output is the input of next iteration.
- #
- # This cannot be automatically done in graph lowering because body_fn's graph outputs
- # are not user-facing so the special handling for strides of user-facing output in graph
- # lowering is not applicable.
- if subgraph is body_fn:
- assert len(subgraph.graph.graph_outputs) == len(
- fake_carried_inputs
- )
- subgraph.graph.graph_outputs = _require_exact_strides( # type: ignore[assignment]
- subgraph.graph.graph_outputs,
- fake_carried_inputs,
- )
- assert cond_fn.graph and body_fn.graph
- cond_outputs = cond_fn.graph.graph_outputs
- body_outputs = body_fn.graph.graph_outputs
- if _has_aliased_buffers(body_outputs):
- raise AssertionError(
- "Output aliasing is currently not supported in compiled torch.while_loop. "
- f"The outputs of the body_fn subgraph of torch.while_loop are aliased: {body_outputs}"
- )
- # make sure cond_fn returns a boolean scalar Tensor
- assert len(cond_outputs) == 1, cond_outputs
- p = cond_outputs[0]
- if not isinstance(p, ShapeAsConstantBuffer):
- assert p.get_dtype() == torch.bool, p
- assert len(p.get_size()) == 0, p
- assert len(all_inputs) > 0, (
- "torch.while_loop is assumed to have at least one operand."
- )
- device = all_inputs[0].get_device()
- assert device is not None # to make linter happy
- # make sure carried_inputs_ and body outputs are structurally equivalent
- assert len(carried_inputs_) == len(body_outputs), (
- carried_inputs_,
- body_outputs,
- )
- for i, (op, bo) in enumerate(zip(carried_inputs_, body_outputs)):
- def _guard_list_equals(
- lhs_exprs: Sequence[Union[int, sympy.Expr]],
- rhs_exprs: Sequence[Union[int, sympy.Expr]],
- ) -> None:
- assert len(lhs_exprs) == len(rhs_exprs)
- for lhs, rhs in zip(lhs_exprs, rhs_exprs):
- V.graph.sizevars.check_equals(lhs, rhs)
- _guard_list_equals(op.get_size(), bo.get_size())
- _guard_list_equals(op.get_stride(), bo.get_stride())
- # assume all carried_inputs_ and outputs are on the same device
- # as the MultiOutputLayout below requires single device
- assert op.get_device() == bo.get_device(), (i, op, bo, device)
- assert op.get_dtype() == bo.get_dtype(), (i, op, bo)
- assert device is not None
- unbacked_bindings = resolve_unbacked_bindings(
- V.graph.sizevars.shape_env,
- V.graph.current_node.meta.get("unbacked_bindings", None),
- )
- while_loop = WhileLoop(
- carried_inputs=carried_inputs_,
- additional_inputs=additional_inputs_,
- cond_subgraph=cond_fn,
- body_subgraph=body_fn,
- # asserted above that there is at least one operand
- layout=MultiOutputLayout(device=device),
- unbacked_bindings=unbacked_bindings,
- stack_output=stack_output,
- )
- assert body_fn.graph is not None and isinstance(
- body_fn.graph.module, torch.fx.GraphModule
- ) # to make linter happy
- # Handling input mutations
- mutated_idxs = check_input_alias_and_mutation(
- body_fn.graph.module, fake_all_inputs
- )[3]
- mutated_idx_set = OrderedSet(mutated_idxs)
- mutated_inputs = [all_inputs[idx] for idx in mutated_idx_set]
- # Create all outputs first
- mutated_inputs_iter = iter(mutated_inputs)
- all_outputs: list[IRNode] = []
- while_loop.outputs = []
- while_loop.mutation_outputs = []
- if stack_output:
- assert len(mutated_idx_set) == 0, (
- "NYI: while_loop_stack_output input mutations."
- )
- for idx, output in enumerate(V.graph.current_node.meta["val"]):
- # Create MultiOutput for regular outputs
- multi_out = MultiOutput(
- FixedLayout(
- device=output.device, # type: ignore[arg-type]
- dtype=output.dtype,
- size=[Conditional._maybe_expr(sz) for sz in output.size()],
- stride=[Conditional._maybe_expr(st) for st in output.stride()],
- ),
- while_loop,
- [(list, idx)],
- )
- while_loop.outputs.append(multi_out)
- all_outputs.append(multi_out)
- else:
- for idx, output in enumerate(body_outputs):
- if idx in mutated_idx_set:
- assert idx < len(carried_inputs), "only carries can be mutated."
- # Create MutationOutput for mutated inputs
- mutated_input = next(mutated_inputs_iter)
- while_loop.mutation_outputs.append(
- MutationOutput(mutated_input.layout, mutated_input, while_loop) # type: ignore[attr-defined, union-attr]
- )
- all_outputs.append(mutated_input)
- else:
- multi_out = MultiOutput(
- FixedLayout(
- device=output.get_device(), # type: ignore[arg-type]
- dtype=output.get_dtype(),
- size=output.get_size(),
- stride=output.get_stride(),
- offset=output.get_layout().offset,
- ),
- while_loop,
- [(list, idx)],
- )
- while_loop.outputs.append(multi_out)
- all_outputs.append(multi_out)
- for inp, out in zip(carried_inputs, all_outputs):
- if inp.get_name() in V.graph.graph_inputs:
- # if a carried input of the while_loop is a graph input,
- # it can be returned as is when the number of iterations
- # is zero. due to this, we can't (generally) reuse the
- # output buffers corresponding to the graph inputs, as
- # the inputs may end up being mutated.
- V.graph.never_reuse_buffers.add(out.get_name())
- return all_outputs
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.codegen_while_loop(self, self.stack_output)
- wrapper.codegen_unbacked_symbol_defs_for_outputs(
- self.get_name(), self.outputs, getattr(self, "unbacked_bindings", {})
- )
- def get_unbacked_symbol_defs(self) -> OrderedSet[sympy.Symbol]:
- if unbacked_bindings := getattr(self, "unbacked_bindings", None):
- resolved = resolve_unbacked_bindings(
- V.graph.sizevars.shape_env, unbacked_bindings
- )
- assert resolved is not None
- return OrderedSet(resolved.keys())
- else:
- return OrderedSet()
- class EffectfulKernel(FallbackKernel):
- def __init__(
- self,
- layout: OutputSpec,
- kernel: _OpOverloads,
- tensor_args: Sequence[IRNode],
- nontensor_args: Sequence[Any],
- unflatten_args: Callable[..., Any],
- kwargs: Optional[dict[str, Any]] = None,
- *,
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
- ) -> None:
- super().__init__(
- layout,
- kernel,
- tensor_args,
- nontensor_args,
- unflatten_args,
- kwargs=None,
- unbacked_bindings=unbacked_bindings,
- )
- from torch._higher_order_ops.effects import get_effect_key
- uncovered_args = [
- a.value if isinstance(a, TorchBindObject) else a for a in tensor_args
- ]
- effect_type = get_effect_key(kernel, (*nontensor_args, *uncovered_args), kwargs)
- assert effect_type is not None
- self.effect_type = effect_type
- self.prev_effect_buffer = V.graph.effectful_ops.get(effect_type, None)
- V.graph.effectful_ops[effect_type] = self
- def get_read_writes(self) -> dependencies.ReadWrites:
- read_writes = super().get_read_writes()
- if self.prev_effect_buffer is not None:
- read_writes.reads.add(
- dependencies.StarDep(self.prev_effect_buffer.get_name())
- )
- return read_writes
- def has_side_effects(self) -> bool:
- return True
- class NonTensorObj(IRNode):
- @cache_on_self_and_args("NonTensorObj")
- def get_free_symbol_uses(
- self, unbacked_only: bool = False
- ) -> OrderedSet[sympy.Symbol]:
- return OrderedSet()
- @ir_dataclass
- class TorchBindObject(NonTensorObj):
- name: str
- value: Union[FakeScriptObject, torch.ScriptObject]
- def get_name(self) -> str:
- return self.name
- def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
- return self.name
- def get_value(self) -> Union[FakeScriptObject, torch.ScriptObject]:
- return self.value
- def get_real_obj(self) -> torch.ScriptObject:
- if isinstance(self.value, torch.ScriptObject):
- return self.value
- else:
- return self.value.real_obj
- def get_buf_bytes(self) -> int:
- # Returns the sum of all tensors in the flattened object
- real_script_obj = self.get_real_obj()
- assert hasattr(real_script_obj, "__obj_flatten__")
- flat_dict = dict(real_script_obj.__obj_flatten__())
- flat_elems = pytree.tree_flatten(flat_dict)[0]
- flat_sizes = [
- x.element_size() * x.numel()
- for x in flat_elems
- if isinstance(x, torch.Tensor)
- ]
- return functools.reduce(operator.add, flat_sizes, 0)
- @ir_dataclass
- class GeneratorState(NonTensorObj):
- name: str
- device: torch.device
- def get_name(self) -> str:
- return self.name
- def codegen_reference(self, writer: Optional[IndentedBuffer] = None) -> str:
- return self.name
- class _CollectiveKernel(FallbackKernel):
- def should_allocate(self) -> bool:
- return False
- def has_side_effects(self) -> bool:
- return True
- # This is identical to FallbackKernel.set_cpp_kernel(), minus the
- # part that checks against input aliasing and mutation.
- def set_cpp_kernel_name(self, cpp_kernel_name: Optional[str] = None) -> None:
- assert type(self.op_overload) is torch._ops.OpOverload, (
- "Setting cpp kernel needs a valid op_overload"
- )
- kernel = self.op_overload
- if cpp_kernel_name is not None:
- self.cpp_kernel_name = cpp_kernel_name
- else:
- self.cpp_kernel_name = kernel._schema.name
- self.ordered_kwargs_for_cpp_kernel = [
- x.name for x in kernel._schema.arguments if x.kwarg_only
- ]
- # NOTE: [In-Place Collective Safety]
- # Between the initiation and completion of an in-place collective, the
- # input buffers are subject to both volatile reads and volatile writes.
- # They must not be read, written to or reused by another kernel. To ensure
- # the constraints, we model collective -> wait_tensor as as two-step
- # mutation of the input buffers.
- @classmethod
- def create_inplace(
- cls,
- kernel: _OpOverloads,
- inputs: Union[IRNode, list[IRNode]],
- *args: Any,
- **kwargs: Any,
- ) -> None:
- with V.graph.fake_mode:
- (
- _example_output,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- unbacked_bindings,
- ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
- assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
- for tensor_arg in tensor_args:
- tensor_arg.realize()
- device = tensor_args[0].get_device()
- packed = cls(
- NoneLayout(device=device),
- kernel,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- )
- inps = pytree.tree_leaves(inputs)
- packed.mutation_outputs.extend(
- [MutationOutput(NoneLayout(device=device), buf, packed) for buf in inps]
- )
- # For inplace collective ops, the input is guaranteed to be alias of the returned value of op.
- packed.alias_names.extend([inp.get_name() for inp in inps])
- if "out" in kwargs:
- packed.mutation_outputs.append(
- MutationOutput(NoneLayout(device=device), kwargs["out"], packed)
- )
- # For out-variant collective ops, the `out=` arg is guaranteed to be alias of the returned value of op.
- packed.alias_names.append(kwargs["out"].get_name())
- # NOTE: [Out-of-Place Collective Safety]
- # Between the initiation and completion of an out-of-place collective:
- #
- # Input buffers:
- # - Are subject to volatile reads
- # - Can be read by another kernel
- # - Must not be written to or reused by another kernel
- #
- # Output buffers:
- # - Are subject to volatile writes
- # - Must not be read, written to or reused by another kernel
- #
- # To ensure the safety of input buffers without sacrificing read
- # availability, we add input buffers as read deps of wait_tensor kernels.
- #
- # To ensure the safety of output buffers, we model wait_tensor as a
- # mutation to the output buffer. Note we also assumes the user program being
- # correct and the output buffer is not consumed by kernels other than
- # wait_tensor.
- #
- # TODO(yifu): add a pre-grad pass to validate the correctness of collective
- # usage in the user program.
- @classmethod
- def create_out_of_place(
- cls,
- kernel: _OpOverloads,
- inputs: Union[TensorBox, list[TensorBox]],
- *args: Any,
- **kwargs: Any,
- ) -> Union[list[MultiOutput], _CollectiveKernel]:
- with V.graph.fake_mode:
- (
- example_output,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- unbacked_bindings,
- ) = cls.process_kernel(kernel, inputs, *args, **kwargs)
- assert not unbacked_bindings, f"{kernel}, {unbacked_bindings}"
- for tensor_arg in tensor_args:
- tensor_arg.realize()
- if isinstance(example_output, list):
- device = cls.find_device(tensor_args, example_output)
- assert device is not None
- packed = cls(
- MultiOutputLayout(device=device),
- kernel,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- )
- packed.outputs = [
- MultiOutput(
- cls.tensor_to_layout(tensor),
- packed,
- [(list, i)],
- )
- for i, tensor in enumerate(example_output)
- ]
- for buf, tensor in zip(packed.outputs, example_output):
- if config.assume_unaligned_fallback_output or not tensor_is_aligned(
- tensor
- ):
- V.graph.unaligned_buffers.add(buf.name) # type: ignore[arg-type]
- return packed.outputs
- else:
- packed = cls(
- cls.tensor_to_layout(example_output),
- kernel,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- )
- if config.assume_unaligned_fallback_output or not tensor_is_aligned(
- example_output
- ):
- V.graph.unaligned_buffers.add(packed.name) # type: ignore[arg-type]
- packed.outputs = [packed]
- return packed
- class _AllReduce_Kernel(_CollectiveKernel):
- def __init__(
- self,
- layout: OutputSpec,
- kernel: _OpOverloads,
- tensor_args: Sequence[IRNode],
- nontensor_args: Sequence[Any],
- unflatten_args: Callable[..., Any],
- kwargs: Optional[dict[str, Any]] = None,
- *,
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
- ) -> None:
- super().__init__(
- layout,
- kernel,
- tensor_args,
- nontensor_args,
- unflatten_args,
- kwargs=None,
- unbacked_bindings=unbacked_bindings,
- )
- self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce_")
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
- wrapper.generate_extern_kernel_alloc(self)
- if isinstance(self.layout, Layout):
- self.codegen_size_asserts(wrapper)
- class _AllReduceKernel(_CollectiveKernel):
- def __init__(
- self,
- layout: OutputSpec,
- kernel: _OpOverloads,
- tensor_args: Sequence[IRNode],
- nontensor_args: Sequence[Any],
- unflatten_args: Callable[..., Any],
- kwargs: Optional[dict[str, Any]] = None,
- *,
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
- ) -> None:
- super().__init__(
- layout,
- kernel,
- tensor_args,
- nontensor_args,
- unflatten_args,
- kwargs=None,
- unbacked_bindings=unbacked_bindings,
- )
- self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_all_reduce")
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
- wrapper.generate_extern_kernel_alloc(self)
- if isinstance(self.layout, Layout):
- self.codegen_size_asserts(wrapper)
- class _WaitKernel(_CollectiveKernel):
- def __init__(
- self,
- layout: OutputSpec,
- kernel: _OpOverloads,
- tensor_args: Sequence[IRNode],
- nontensor_args: Sequence[Any],
- unflatten_args: Callable[..., Any],
- kwargs: Optional[dict[str, Any]] = None,
- *,
- unbacked_bindings: Optional[dict[sympy.Symbol, pytree.KeyPath]] = None,
- ) -> None:
- super().__init__(
- layout,
- kernel,
- tensor_args,
- nontensor_args,
- unflatten_args,
- kwargs=None,
- unbacked_bindings=unbacked_bindings,
- )
- self.set_cpp_kernel_name("aoti_torch_cpu__c10d_functional_wait_tensor")
- def codegen(self, wrapper: PythonWrapperCodegen) -> None:
- wrapper.include_extra_header("torch/csrc/inductor/aoti_torch/c/shim_cpu.h")
- wrapper.generate_extern_kernel_alloc(self)
- if isinstance(self.layout, Layout):
- self.codegen_size_asserts(wrapper)
- def get_volatile_reads(self) -> Sequence[IRNode]:
- inp = self.inputs[0]
- assert isinstance(inp, IRNode)
- if isinstance(inp, _CollectiveKernel):
- # Out-of-place single-output
- i = inp.inputs[0]
- assert isinstance(i, IRNode), type(i)
- return [i]
- elif isinstance(inp, MultiOutput):
- # This can be two things:
- # 1. Out-of-place multi-output coll
- # 2. In-place coll with inputs coming from another MultiOutput
- coll = inp.inputs[0]
- # Case 1
- if isinstance(coll, _CollectiveKernel):
- _, idx = inp.indices[0]
- return [coll.inputs[idx]]
- # Case 2
- return []
- else:
- # In-place requires no additional deps handling for volatile
- # reads since the inputs are mutated.
- return []
- @classmethod
- def create_wait(cls, kernel: _OpOverloads, inp: TensorBox) -> None:
- with V.graph.fake_mode:
- (
- _example_output,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- unbacked_bindings,
- ) = cls.process_kernel(kernel, inp)
- assert not unbacked_bindings, f"{kernel} {unbacked_bindings}"
- packed = cls(
- NoneLayout(device=inp.get_device()),
- kernel,
- tensor_args,
- non_tensor_args,
- unflatten_args,
- )
- packed.mutation_outputs.append(
- MutationOutput(NoneLayout(device=inp.get_device()), inp, packed)
- )
- def get_read_writes(self) -> dependencies.ReadWrites:
- read_writes = super().get_read_writes()
- # See [Out-of-Place Collective Safety].
- volatile_reads = self.get_volatile_reads()
- for vr in volatile_reads:
- read_writes.reads.add(dependencies.StarDep(vr.get_name()))
- return read_writes
- # NB: recursive structure here reflects val_to_arg_str, avoid
- # calling free_unbacked_symbols on "exotic" types that don't get pexpr
- # treatment
- def maybe_free_unbacked_symbols(s: object) -> OrderedSet[Symbol]:
- if isinstance(s, (SymTypes, Expr)):
- # This branch should be impossible in return position
- return free_unbacked_symbols(s)
- elif isinstance(s, (tuple, list)):
- r = OrderedSet[sympy.Symbol]()
- for t in s:
- r |= maybe_free_unbacked_symbols(t)
- return r
- elif isinstance(s, torch.Tensor):
- # This branch is impossible in constant-args position
- return free_unbacked_symbols(s)
- else:
- return OrderedSet()
- def maybe_free_symbols(s: object) -> OrderedSet[Symbol]:
- if isinstance(s, (SymTypes, Expr)):
- # This branch should be impossible in return position
- return free_symbols(s)
- elif isinstance(s, (tuple, list)):
- r = OrderedSet[sympy.Symbol]()
- for t in s:
- r |= maybe_free_symbols(t)
- return r
- elif isinstance(s, torch.Tensor):
- # This branch is impossible in constant-args position
- return free_symbols(s)
- else:
- return OrderedSet()
|