modeling_utils.py 302 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350435143524353435443554356435743584359436043614362436343644365436643674368436943704371437243734374437543764377437843794380438143824383438443854386438743884389439043914392439343944395439643974398439944004401440244034404440544064407440844094410441144124413441444154416441744184419442044214422442344244425442644274428442944304431443244334434443544364437443844394440444144424443444444454446444744484449445044514452445344544455445644574458445944604461446244634464446544664467446844694470447144724473447444754476447744784479448044814482448344844485448644874488448944904491449244934494449544964497449844994500450145024503450445054506450745084509451045114512451345144515451645174518451945204521452245234524452545264527452845294530453145324533453445354536453745384539454045414542454345444545454645474548454945504551455245534554455545564557455845594560456145624563456445654566456745684569457045714572457345744575457645774578457945804581458245834584458545864587458845894590459145924593459445954596459745984599460046014602460346044605460646074608460946104611461246134614461546164617461846194620462146224623462446254626462746284629463046314632463346344635463646374638463946404641464246434644464546464647464846494650465146524653465446554656465746584659466046614662466346644665466646674668466946704671467246734674467546764677467846794680468146824683468446854686468746884689469046914692469346944695469646974698469947004701470247034704470547064707470847094710471147124713471447154716471747184719472047214722472347244725472647274728472947304731473247334734473547364737473847394740474147424743474447454746474747484749475047514752475347544755475647574758475947604761476247634764476547664767476847694770477147724773477447754776477
74778477947804781478247834784478547864787478847894790479147924793479447954796479747984799480048014802480348044805480648074808480948104811481248134814481548164817481848194820482148224823482448254826482748284829483048314832483348344835483648374838483948404841484248434844484548464847484848494850485148524853485448554856485748584859486048614862486348644865486648674868486948704871487248734874487548764877487848794880488148824883488448854886488748884889489048914892489348944895489648974898489949004901490249034904490549064907490849094910491149124913491449154916491749184919492049214922492349244925492649274928492949304931493249334934493549364937493849394940494149424943494449454946494749484949495049514952495349544955495649574958495949604961496249634964496549664967496849694970497149724973497449754976497749784979498049814982498349844985498649874988498949904991499249934994499549964997499849995000500150025003500450055006500750085009501050115012501350145015501650175018501950205021502250235024502550265027502850295030503150325033503450355036503750385039504050415042504350445045504650475048504950505051505250535054505550565057505850595060506150625063506450655066506750685069507050715072507350745075507650775078507950805081508250835084508550865087508850895090509150925093509450955096509750985099510051015102510351045105510651075108510951105111511251135114511551165117511851195120512151225123512451255126512751285129513051315132513351345135513651375138513951405141514251435144514551465147514851495150515151525153515451555156515751585159516051615162516351645165516651675168516951705171517251735174517551765177517851795180518151825183518451855186518751885189519051915192519351945195519651975198519952005201520252035204520552065207520852095210521152125213521452155216521752185219522052215222522352245225522652275228522952305231523252335234523552365237523852395240524152425243524452455246524752485249525052515252525352545255525652575258525952605261526252635264526552665267526852695270527152725273527452755276527
75278527952805281528252835284528552865287528852895290529152925293529452955296529752985299530053015302530353045305530653075308530953105311531253135314531553165317531853195320532153225323532453255326532753285329533053315332533353345335533653375338533953405341534253435344534553465347534853495350535153525353535453555356535753585359536053615362536353645365536653675368536953705371537253735374537553765377537853795380538153825383538453855386538753885389539053915392539353945395539653975398539954005401540254035404540554065407540854095410541154125413541454155416541754185419542054215422542354245425542654275428542954305431543254335434543554365437543854395440544154425443544454455446544754485449545054515452545354545455545654575458545954605461546254635464546554665467546854695470547154725473547454755476547754785479548054815482548354845485548654875488548954905491549254935494549554965497549854995500550155025503550455055506550755085509551055115512551355145515551655175518551955205521552255235524552555265527552855295530553155325533553455355536553755385539554055415542554355445545554655475548554955505551555255535554555555565557555855595560556155625563556455655566556755685569557055715572557355745575557655775578557955805581558255835584558555865587558855895590559155925593559455955596559755985599560056015602560356045605560656075608560956105611561256135614561556165617561856195620562156225623562456255626562756285629563056315632563356345635563656375638563956405641564256435644564556465647564856495650565156525653565456555656565756585659566056615662566356645665566656675668566956705671567256735674567556765677567856795680568156825683568456855686568756885689569056915692569356945695569656975698569957005701570257035704570557065707570857095710571157125713571457155716571757185719572057215722572357245725572657275728572957305731573257335734573557365737573857395740574157425743574457455746574757485749575057515752575357545755575657575758575957605761576257635764576557665767576857695770577157725773577457755776577
75778577957805781578257835784578557865787578857895790579157925793579457955796579757985799580058015802580358045805580658075808580958105811581258135814581558165817581858195820582158225823582458255826582758285829583058315832583358345835583658375838583958405841584258435844584558465847584858495850585158525853585458555856585758585859586058615862586358645865586658675868586958705871587258735874587558765877587858795880588158825883588458855886588758885889589058915892589358945895589658975898589959005901590259035904590559065907590859095910591159125913591459155916591759185919592059215922592359245925592659275928592959305931593259335934593559365937593859395940594159425943594459455946594759485949595059515952595359545955595659575958595959605961596259635964596559665967596859695970597159725973597459755976597759785979598059815982598359845985598659875988598959905991599259935994599559965997599859996000600160026003600460056006600760086009601060116012601360146015601660176018601960206021602260236024602560266027602860296030603160326033603460356036603760386039604060416042604360446045604660476048604960506051605260536054605560566057605860596060606160626063606460656066606760686069607060716072607360746075607660776078607960806081608260836084608560866087608860896090609160926093609460956096609760986099610061016102610361046105610661076108610961106111611261136114611561166117611861196120612161226123612461256126612761286129613061316132613361346135613661376138613961406141614261436144614561466147614861496150615161526153615461556156615761586159616061616162616361646165
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors, Facebook AI Research authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. import collections
  17. import copy
  18. import functools
  19. import gc
  20. import importlib.metadata
  21. import inspect
  22. import json
  23. import os
  24. import re
  25. import sys
  26. import warnings
  27. from abc import abstractmethod
  28. from collections import defaultdict
  29. from concurrent.futures import ThreadPoolExecutor, as_completed
  30. from contextlib import contextmanager
  31. from enum import Enum
  32. from functools import partial, wraps
  33. from threading import Thread
  34. from typing import Any, Callable, Optional, TypeVar, Union, get_type_hints
  35. from zipfile import is_zipfile
  36. import torch
  37. from huggingface_hub import split_torch_state_dict_into_shards
  38. from packaging import version
  39. from safetensors import safe_open
  40. from safetensors.torch import load_file as safe_load_file
  41. from safetensors.torch import save_file as safe_save_file
  42. from torch import Tensor, nn
  43. from torch.distributions import constraints
  44. from torch.utils.checkpoint import checkpoint
  45. from .configuration_utils import PretrainedConfig
  46. from .distributed import DistributedConfig
  47. from .dynamic_module_utils import custom_object_save
  48. from .generation import CompileConfig, GenerationConfig
  49. from .integrations import PeftAdapterMixin, deepspeed_config, is_deepspeed_zero3_enabled, is_fsdp_enabled
  50. from .integrations.accelerate import find_tied_parameters, init_empty_weights
  51. from .integrations.deepspeed import _load_state_dict_into_zero3_model
  52. from .integrations.eager_paged import eager_paged_attention_forward
  53. from .integrations.flash_attention import flash_attention_forward
  54. from .integrations.flash_paged import paged_attention_forward
  55. from .integrations.flex_attention import flex_attention_forward
  56. from .integrations.hub_kernels import is_kernel, load_and_register_kernel
  57. from .integrations.sdpa_attention import sdpa_attention_forward
  58. from .integrations.sdpa_paged import sdpa_attention_paged_forward
  59. from .integrations.tensor_parallel import (
  60. _get_parameter_tp_plan,
  61. distribute_model,
  62. initialize_tensor_parallelism,
  63. repack_weights,
  64. replace_state_dict_local_with_dtensor,
  65. shard_and_distribute_module,
  66. verify_tp_plan,
  67. )
  68. from .loss.loss_utils import LOSS_MAPPING
  69. from .modeling_flash_attention_utils import lazy_import_flash_attention
  70. from .pytorch_utils import id_tensor_storage
  71. from .quantizers import HfQuantizer
  72. from .quantizers.auto import get_hf_quantizer
  73. from .quantizers.quantizers_utils import get_module_from_name
  74. from .safetensors_conversion import auto_conversion
  75. from .utils import (
  76. ADAPTER_SAFE_WEIGHTS_NAME,
  77. ADAPTER_WEIGHTS_NAME,
  78. CONFIG_NAME,
  79. DUMMY_INPUTS,
  80. FLAX_WEIGHTS_NAME,
  81. SAFE_WEIGHTS_INDEX_NAME,
  82. SAFE_WEIGHTS_NAME,
  83. TF2_WEIGHTS_NAME,
  84. TF_WEIGHTS_NAME,
  85. WEIGHTS_INDEX_NAME,
  86. WEIGHTS_NAME,
  87. ContextManagers,
  88. PushToHubMixin,
  89. cached_file,
  90. check_torch_load_is_safe,
  91. copy_func,
  92. download_url,
  93. extract_commit_hash,
  94. has_file,
  95. is_accelerate_available,
  96. is_bitsandbytes_available,
  97. is_flash_attn_2_available,
  98. is_flash_attn_3_available,
  99. is_kernels_available,
  100. is_offline_mode,
  101. is_optimum_available,
  102. is_peft_available,
  103. is_remote_url,
  104. is_torch_flex_attn_available,
  105. is_torch_greater_or_equal,
  106. is_torch_mlu_available,
  107. is_torch_npu_available,
  108. is_torch_xla_available,
  109. is_torch_xpu_available,
  110. logging,
  111. )
  112. from .utils.generic import _CAN_RECORD_REGISTRY, GeneralInterface, OutputRecorder
  113. from .utils.hub import create_and_tag_model_card, get_checkpoint_shard_files
  114. from .utils.import_utils import (
  115. ENV_VARS_TRUE_VALUES,
  116. is_huggingface_hub_greater_or_equal,
  117. is_sagemaker_mp_enabled,
  118. is_torch_fx_proxy,
  119. is_torchdynamo_compiling,
  120. )
  121. from .utils.quantization_config import BitsAndBytesConfig, QuantizationMethod
  122. if is_accelerate_available():
  123. from accelerate import dispatch_model, infer_auto_device_map
  124. from accelerate.hooks import add_hook_to_module
  125. from accelerate.utils import (
  126. check_tied_parameters_on_same_device,
  127. extract_model_from_parallel,
  128. get_balanced_memory,
  129. get_max_memory,
  130. offload_weight,
  131. save_offload_index,
  132. )
  133. accelerate_version = version.parse(importlib.metadata.version("accelerate"))
  134. if accelerate_version >= version.parse("0.31"):
  135. from accelerate.utils.modeling import get_state_dict_from_offload
  136. if is_peft_available():
  137. from .utils import find_adapter_config_file
  138. _torch_distributed_available = torch.distributed.is_available()
  139. _is_dtensor_available = _torch_distributed_available and is_torch_greater_or_equal("2.5")
  140. if _is_dtensor_available:
  141. from torch.distributed.tensor import DTensor
  142. if is_sagemaker_mp_enabled():
  143. import smdistributed.modelparallel.torch as smp
  144. from smdistributed.modelparallel import __version__ as SMP_VERSION
  145. IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
  146. else:
  147. IS_SAGEMAKER_MP_POST_1_10 = False
logger = logging.get_logger(__name__)
# Raw values of the XLA_USE_BF16 / XLA_DOWNCAST_BF16 environment variables, upper-cased,
# defaulting to "0" when unset. NOTE(review): presumably compared against truthy strings
# such as "1"/"TRUE" elsewhere in the file — confirm at the use sites.
XLA_USE_BF16 = os.environ.get("XLA_USE_BF16", "0").upper()
XLA_DOWNCAST_BF16 = os.environ.get("XLA_DOWNCAST_BF16", "0").upper()
# TypeVar bound to `PreTrainedModel` (forward reference by name). Presumably used to annotate
# methods that return an instance of the concrete subclass — verify against its users.
SpecificPreTrainedModelType = TypeVar("SpecificPreTrainedModelType", bound="PreTrainedModel")
# Module-level state flags toggled by the context managers defined in this module:
# - `_init_weights` is set to False while `no_init_weights()` is active.
# - `_is_quantized` is set to True while `set_quantized_state()` is active.
# - `_is_ds_init_called` is set to True while `set_zero3_state()` is active.
_init_weights = True
_is_quantized = False
_is_ds_init_called = False
  155. def is_local_dist_rank_0():
  156. return (
  157. torch.distributed.is_available()
  158. and torch.distributed.is_initialized()
  159. and int(os.environ.get("LOCAL_RANK", "-1")) == 0
  160. )
  161. TORCH_INIT_FUNCTIONS = {
  162. "uniform_": nn.init.uniform_,
  163. "normal_": nn.init.normal_,
  164. "trunc_normal_": nn.init.trunc_normal_,
  165. "constant_": nn.init.constant_,
  166. "xavier_uniform_": nn.init.xavier_uniform_,
  167. "xavier_normal_": nn.init.xavier_normal_,
  168. "kaiming_uniform_": nn.init.kaiming_uniform_,
  169. "kaiming_normal_": nn.init.kaiming_normal_,
  170. "uniform": nn.init.uniform,
  171. "normal": nn.init.normal,
  172. "xavier_uniform": nn.init.xavier_uniform,
  173. "xavier_normal": nn.init.xavier_normal,
  174. "kaiming_uniform": nn.init.kaiming_uniform,
  175. "kaiming_normal": nn.init.kaiming_normal,
  176. }
# DO NOT MODIFY, KEPT FOR BC ONLY
# Frozen list of model-type identifiers, presumably vision-language model (VLM) families
# judging by the name. Nothing in this part of the file reads it.
# NOTE(review): confirm where (if anywhere) it is still consumed before relying on it.
VLMS = [
    "aria",
    "ayavision",
    "colpali",
    "emu3",
    "fuyu",
    "gotocr2",
    "gemma3",
    "internvl",
    "llava",  # all llava prefixed models fall under this check
    "mistral3",
    "mllama",
    "paligemma",
    "shieldgemma2",
    "qwen2vl",
    "qwen2_5_vl",
    "videollava",
    "vipllava",
]
  197. @contextmanager
  198. def no_init_weights():
  199. """
  200. Context manager to globally disable weight initialization to speed up loading large models.
  201. """
  202. global _init_weights
  203. old_init_weights = _init_weights
  204. _init_weights = False
  205. def _skip_init(*args, **kwargs):
  206. pass
  207. # Save the original initialization functions
  208. for name, init_func in TORCH_INIT_FUNCTIONS.items():
  209. setattr(torch.nn.init, name, _skip_init)
  210. try:
  211. yield
  212. finally:
  213. _init_weights = old_init_weights
  214. # Restore the original initialization functions
  215. for name, init_func in TORCH_INIT_FUNCTIONS.items():
  216. setattr(torch.nn.init, name, init_func)
  217. @contextmanager
  218. def set_quantized_state():
  219. global _is_quantized
  220. _is_quantized = True
  221. try:
  222. yield
  223. finally:
  224. _is_quantized = False
  225. # Skip recursive calls to deepspeed.zero.Init to avoid pinning errors.
  226. # This issue occurs with ZeRO stage 3 when using NVMe offloading.
  227. # For more details, refer to issue #34429.
  228. @contextmanager
  229. def set_zero3_state():
  230. global _is_ds_init_called
  231. _is_ds_init_called = True
  232. try:
  233. yield
  234. finally:
  235. _is_ds_init_called = False
  236. def restore_default_dtype(func):
  237. """
  238. Decorator to restore the default torch dtype
  239. at the end of the function. Serves
  240. as a backup in case calling the function raises
  241. an error after the function has changed the default dtype but before it could restore it.
  242. """
  243. @wraps(func)
  244. def _wrapper(*args, **kwargs):
  245. old_dtype = torch.get_default_dtype()
  246. try:
  247. return func(*args, **kwargs)
  248. finally:
  249. torch.set_default_dtype(old_dtype)
  250. return _wrapper
  251. def get_torch_context_manager_or_global_device():
  252. """
  253. Test if a device context manager is currently in use, or if it is not the case, check if the default device
  254. is not "cpu". This is used to infer the correct device to load the model on, in case `device_map` is not provided.
  255. """
  256. device_in_context = torch.tensor([]).device
  257. # `get_default_device` was only introduced in torch>=2.3 - use cpu otherwise to align the behavior
  258. default_device = torch.get_default_device() if is_torch_greater_or_equal("2.3") else torch.device("cpu")
  259. # This case means no context manager was used -> we still check if the default that was potentially set is not cpu
  260. if device_in_context == default_device:
  261. if default_device != torch.device("cpu"):
  262. return default_device
  263. return None
  264. return device_in_context
  265. def get_parameter_device(parameter: Union[nn.Module, "ModuleUtilsMixin"]):
  266. try:
  267. return next(parameter.parameters()).device
  268. except StopIteration:
  269. # For nn.DataParallel compatibility in PyTorch 1.5
  270. def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
  271. tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
  272. return tuples
  273. gen = parameter._named_members(get_members_fn=find_tensor_attributes)
  274. first_tuple = next(gen)
  275. return first_tuple[1].device
def get_parameter_dtype(parameter: Union[nn.Module, "ModuleUtilsMixin"]) -> Optional[torch.dtype]:
    """
    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.

    Lookup order:
    1. Parameters: the first floating-point one wins, subject to the XLA overrides below. If there are parameters but
       none is floating, the dtype of the last one is returned.
    2. Tensors stored directly as module attributes (`nn.DataParallel` compatibility): same rule.
    3. Buffers: same rule.

    Returns `None` when the module has no parameters, no tensor attributes and no buffers.
    """
    last_dtype = None
    for t in parameter.parameters():
        last_dtype = t.dtype
        if t.is_floating_point():
            # Adding fix for https://github.com/pytorch/xla/issues/4152
            # Fixes issue where the model code passes a value that is out of range for XLA_USE_BF16=1
            # and XLA_DOWNCAST_BF16=1 so the conversion would cast it to -inf
            # NOTE: `is_torch_xla_available()` is checked last as it induces a graph break in torch dynamo
            if XLA_USE_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
                return torch.bfloat16
            if XLA_DOWNCAST_BF16 in ENV_VARS_TRUE_VALUES and is_torch_xla_available():
                # Report the downcast dtype: float32 -> bfloat16, float64 -> float32
                if t.dtype == torch.float:
                    return torch.bfloat16
                if t.dtype == torch.double:
                    return torch.float32
            return t.dtype

    if last_dtype is not None:
        # if no floating dtype was found return whatever the last dtype is
        return last_dtype

    # For nn.DataParallel compatibility in PyTorch > 1.5
    def find_tensor_attributes(module: nn.Module) -> list[tuple[str, Tensor]]:
        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
        return tuples

    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
    last_tuple = None
    for gen_tuple in gen:
        last_tuple = gen_tuple
        if gen_tuple[1].is_floating_point():
            return gen_tuple[1].dtype

    if last_tuple is not None:
        # fallback to the last dtype
        return last_tuple[1].dtype

    # fallback to buffer dtype (last_dtype is still None here, so this returns None if there are no buffers)
    for t in parameter.buffers():
        last_dtype = t.dtype
        if t.is_floating_point():
            return t.dtype
    return last_dtype
  318. def get_state_dict_dtype(state_dict):
  319. """
  320. Returns the first found floating dtype in `state_dict` if there is one, otherwise returns the first dtype.
  321. """
  322. for t in state_dict.values():
  323. if t.is_floating_point():
  324. return t.dtype
  325. # if no floating dtype was found return whatever the first dtype is
  326. return next(state_dict.values()).dtype
  327. def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
  328. """
  329. This is the same as
  330. [`torch.nn.Module.load_state_dict`](https://pytorch.org/docs/stable/generated/torch.nn.Module.html?highlight=load_state_dict#torch.nn.Module.load_state_dict)
  331. but for a sharded checkpoint.
  332. This load is performed efficiently: each checkpoint shard is loaded one by one in RAM and deleted after being
  333. loaded in the model.
  334. Args:
  335. model (`torch.nn.Module`): The model in which to load the checkpoint.
  336. folder (`str` or `os.PathLike`): A path to a folder containing the sharded checkpoint.
  337. strict (`bool`, *optional*, defaults to `True`):
  338. Whether to strictly enforce that the keys in the model state dict match the keys in the sharded checkpoint.
  339. prefer_safe (`bool`, *optional*, defaults to `False`):
  340. If both safetensors and PyTorch save files are present in checkpoint and `prefer_safe` is True, the
  341. safetensors files will be loaded. Otherwise, PyTorch files are always loaded when possible.
  342. Returns:
  343. `NamedTuple`: A named tuple with `missing_keys` and `unexpected_keys` fields
  344. - `missing_keys` is a list of str containing the missing keys
  345. - `unexpected_keys` is a list of str containing the unexpected keys
  346. """
  347. # Load the index
  348. index_file = os.path.join(folder, WEIGHTS_INDEX_NAME)
  349. safe_index_file = os.path.join(folder, SAFE_WEIGHTS_INDEX_NAME)
  350. index_present = os.path.isfile(index_file)
  351. safe_index_present = os.path.isfile(safe_index_file)
  352. if not index_present and not safe_index_present:
  353. filenames = (WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_INDEX_NAME)
  354. raise ValueError(f"Can't find a checkpoint index ({' or '.join(filenames)}) in {folder}.")
  355. load_safe = safe_index_present and (prefer_safe or not index_present)
  356. load_index = safe_index_file if load_safe else index_file
  357. with open(load_index, "r", encoding="utf-8") as f:
  358. index = json.load(f)
  359. shard_files = list(set(index["weight_map"].values()))
  360. # If strict=True, error before loading any of the state dicts.
  361. loaded_keys = index["weight_map"].keys()
  362. model_keys = model.state_dict().keys()
  363. missing_keys = [key for key in model_keys if key not in loaded_keys]
  364. unexpected_keys = [key for key in loaded_keys if key not in model_keys]
  365. if strict and (len(missing_keys) > 0 or len(unexpected_keys) > 0):
  366. error_message = f"Error(s) in loading state_dict for {model.__class__.__name__}"
  367. if len(missing_keys) > 0:
  368. str_missing_keys = ",".join([f'"{k}"' for k in missing_keys])
  369. error_message += f"\nMissing key(s): {str_missing_keys}."
  370. if len(unexpected_keys) > 0:
  371. str_unexpected_keys = ",".join([f'"{k}"' for k in unexpected_keys])
  372. error_message += f"\nMissing key(s): {str_unexpected_keys}."
  373. raise RuntimeError(error_message)
  374. if load_safe:
  375. loader = safe_load_file
  376. else:
  377. check_torch_load_is_safe()
  378. loader = partial(torch.load, map_location="cpu", weights_only=True)
  379. for shard_file in shard_files:
  380. state_dict = loader(os.path.join(folder, shard_file))
  381. model.load_state_dict(state_dict, strict=False)
  382. # Make sure memory is freed before we load the next state dict.
  383. del state_dict
  384. gc.collect()
  385. # Return the same thing as PyTorch load_state_dict function.
  386. return torch.nn.modules.module._IncompatibleKeys(missing_keys, unexpected_keys)
# Maps the dtype strings reported by safetensors (`f.get_slice(k).get_dtype()` in `load_state_dict`) to torch dtypes.
# `load_state_dict` uses it to build empty "meta" tensors without reading any tensor data.
str_to_torch_dtype = {
    "BOOL": torch.bool,
    "U8": torch.uint8,
    "I8": torch.int8,
    "I16": torch.int16,
    "F16": torch.float16,
    "BF16": torch.bfloat16,
    "I32": torch.int32,
    "F32": torch.float32,
    "F64": torch.float64,
    "I64": torch.int64,
    "F8_E4M3": torch.float8_e4m3fn,
    "F8_E5M2": torch.float8_e5m2,
}

# Unsigned 16/32/64-bit dtypes are only registered for torch>=2.3 (presumably because the `torch.uint16/32/64`
# attributes are not available in older releases).
if is_torch_greater_or_equal("2.3.0"):
    str_to_torch_dtype["U16"] = torch.uint16
    str_to_torch_dtype["U32"] = torch.uint32
    str_to_torch_dtype["U64"] = torch.uint64
  405. def load_state_dict(
  406. checkpoint_file: Union[str, os.PathLike],
  407. is_quantized: bool = False,
  408. map_location: Optional[Union[str, torch.device]] = "cpu",
  409. weights_only: bool = True,
  410. ):
  411. """
  412. Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
  413. """
  414. # Use safetensors if possible
  415. if checkpoint_file.endswith(".safetensors"):
  416. with safe_open(checkpoint_file, framework="pt") as f:
  417. metadata = f.metadata()
  418. if metadata is not None and metadata.get("format") not in ["pt", "tf", "flax", "mlx"]:
  419. raise OSError(
  420. f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
  421. "you save your model with the `save_pretrained` method."
  422. )
  423. state_dict = {}
  424. for k in f.keys():
  425. if map_location == "meta":
  426. _slice = f.get_slice(k)
  427. k_dtype = _slice.get_dtype()
  428. if k_dtype in str_to_torch_dtype:
  429. dtype = str_to_torch_dtype[k_dtype]
  430. else:
  431. raise ValueError(f"Cannot load safetensors of unknown dtype {k_dtype}")
  432. state_dict[k] = torch.empty(size=_slice.get_shape(), dtype=dtype, device="meta")
  433. else:
  434. state_dict[k] = f.get_tensor(k)
  435. return state_dict
  436. # Fallback to torch.load (if weights_only was explicitly False, do not check safety as this is known to be unsafe)
  437. if weights_only:
  438. check_torch_load_is_safe()
  439. try:
  440. if map_location is None:
  441. if (
  442. (
  443. is_deepspeed_zero3_enabled()
  444. and torch.distributed.is_initialized()
  445. and torch.distributed.get_rank() > 0
  446. )
  447. or (is_fsdp_enabled() and not is_local_dist_rank_0())
  448. ) and not is_quantized:
  449. map_location = "meta"
  450. else:
  451. map_location = "cpu"
  452. extra_args = {}
  453. # mmap can only be used with files serialized with zipfile-based format.
  454. if isinstance(checkpoint_file, str) and map_location != "meta" and is_zipfile(checkpoint_file):
  455. extra_args = {"mmap": True}
  456. return torch.load(
  457. checkpoint_file,
  458. map_location=map_location,
  459. weights_only=weights_only,
  460. **extra_args,
  461. )
  462. except Exception as e:
  463. try:
  464. with open(checkpoint_file) as f:
  465. if f.read(7) == "version":
  466. raise OSError(
  467. "You seem to have cloned a repository without having git-lfs installed. Please install "
  468. "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
  469. "you cloned."
  470. )
  471. else:
  472. raise ValueError(
  473. f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
  474. "model. Make sure you have saved the model properly."
  475. ) from e
  476. except (UnicodeDecodeError, ValueError):
  477. raise OSError(
  478. f"Unable to load weights from pytorch checkpoint file for '{checkpoint_file}' "
  479. f"at '{checkpoint_file}'. "
  480. "If you tried to load a PyTorch model from a TF 2.0 checkpoint, please set from_tf=True."
  481. )
  482. def _end_ptr(tensor: torch.Tensor) -> int:
  483. # extract the end of the pointer if the tensor is a slice of a bigger tensor
  484. if tensor.nelement():
  485. stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
  486. else:
  487. stop = tensor.data_ptr()
  488. return stop
  489. def _get_tied_weight_keys(module: nn.Module, prefix=""):
  490. tied_weight_keys = []
  491. if getattr(module, "_tied_weights_keys", None) is not None:
  492. names = [f"{prefix}.{k}" if prefix else k for k in module._tied_weights_keys]
  493. tied_weight_keys.extend(names)
  494. if getattr(module, "_dynamic_tied_weights_keys", None) is not None:
  495. names = [f"{prefix}.{k}" if prefix else k for k in module._dynamic_tied_weights_keys]
  496. tied_weight_keys.extend(names)
  497. for name, submodule in module.named_children():
  498. local_prefix = f"{prefix}.{name}" if prefix else name
  499. tied_weight_keys.extend(_get_tied_weight_keys(submodule, prefix=local_prefix))
  500. return tied_weight_keys
  501. def _find_disjoint(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]) -> tuple[list[set[str]], list[str]]:
  502. filtered_tensors = []
  503. for shared in tensors:
  504. if len(shared) < 2:
  505. filtered_tensors.append(shared)
  506. continue
  507. areas = []
  508. for name in shared:
  509. tensor = state_dict[name]
  510. areas.append((tensor.data_ptr(), _end_ptr(tensor), name))
  511. areas.sort()
  512. _, last_stop, last_name = areas[0]
  513. filtered_tensors.append({last_name})
  514. for start, stop, name in areas[1:]:
  515. if start >= last_stop:
  516. filtered_tensors.append({name})
  517. else:
  518. filtered_tensors[-1].add(name)
  519. last_stop = stop
  520. disjoint_tensors = []
  521. shared_tensors = []
  522. for tensors in filtered_tensors:
  523. if len(tensors) == 1:
  524. disjoint_tensors.append(tensors.pop())
  525. else:
  526. shared_tensors.append(tensors)
  527. return shared_tensors, disjoint_tensors
  528. def _find_identical(tensors: list[set[str]], state_dict: dict[str, torch.Tensor]) -> tuple[list[set[str]], set[str]]:
  529. shared_tensors = []
  530. identical = []
  531. for shared in tensors:
  532. if len(shared) < 2:
  533. continue
  534. areas = collections.defaultdict(set)
  535. for name in shared:
  536. tensor = state_dict[name]
  537. area = (tensor.device, tensor.data_ptr(), _end_ptr(tensor))
  538. areas[area].add(name)
  539. if len(areas) == 1:
  540. identical.append(shared)
  541. else:
  542. shared_tensors.append(shared)
  543. return shared_tensors, identical
  544. def _infer_parameter_dtype(
  545. model: "PreTrainedModel",
  546. param_name: str,
  547. empty_param: torch.Tensor,
  548. keep_in_fp32_regex: Optional[re.Pattern] = None,
  549. hf_quantizer: Optional[HfQuantizer] = None,
  550. ) -> Union[bool, Optional[torch.dtype]]:
  551. try:
  552. old_param = model.get_parameter_or_buffer(param_name)
  553. except Exception as e:
  554. if hf_quantizer is not None and hf_quantizer.quantization_config.quant_method in {
  555. QuantizationMethod.HQQ,
  556. QuantizationMethod.QUARK,
  557. QuantizationMethod.MXFP4,
  558. QuantizationMethod.BITS_AND_BYTES,
  559. }:
  560. return True, None
  561. else:
  562. raise e
  563. is_torch_e4m3fn_available = hasattr(torch, "float8_e4m3fn")
  564. # We convert floating dtypes to the `dtype` passed except for float8_e4m3fn type. We also want to keep the buffers/params
  565. # in int/uint/bool and not cast them.
  566. casting_dtype = None
  567. is_param_float8_e4m3fn = is_torch_e4m3fn_available and empty_param.dtype == torch.float8_e4m3fn
  568. if empty_param.dtype.is_floating_point and not is_param_float8_e4m3fn:
  569. # First fp32 if part of the exception list
  570. if keep_in_fp32_regex is not None and keep_in_fp32_regex.search(param_name):
  571. casting_dtype = torch.float32
  572. # Then dtype that was instantiated in the meta model -- note that this respects subconfigs dtypes
  573. elif hf_quantizer is not None:
  574. casting_dtype = model.config._pre_quantization_dtype
  575. else:
  576. casting_dtype = old_param.dtype
  577. return old_param is not None and old_param.is_contiguous(), casting_dtype
  578. def _load_parameter_into_model(model: "PreTrainedModel", param_name: str, tensor: torch.Tensor):
  579. """Cast a single parameter `param_name` into the `model`, with value `tensor`."""
  580. module, param_type = get_module_from_name(model, param_name)
  581. # This will check potential shape mismatch if skipped before
  582. module.load_state_dict({param_type: tensor}, strict=False, assign=True)
@torch.no_grad()
def _load_state_dict_into_meta_model(
    model: "PreTrainedModel",
    state_dict: dict,
    shard_file: str,
    reverse_renaming_mapping: dict[str, str],
    device_map: Optional[dict] = None,
    disk_offload_folder: Optional[str] = None,
    disk_offload_index: Optional[dict] = None,
    hf_quantizer: Optional[HfQuantizer] = None,
    keep_in_fp32_regex: Optional[re.Pattern] = None,
    device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
) -> Optional[dict]:
    """Load the parameters listed in `state_dict` into `model`, one at a time.

    Safetensors shards (`shard_file` ending in `.safetensors`):
    - the values of `state_dict` are only used for their metadata (passed as `empty_param`, e.g. "meta" placeholders);
    - the actual data is read lazily from `shard_file` through a `safe_open` handle;
    - `reverse_renaming_mapping` recovers the on-disk name of each key.

    Other files: `state_dict` holds the real tensors, and each entry is deleted after being processed to free memory.

    Each tensor is cast to the dtype from `_infer_parameter_dtype` and made contiguous when required. Then it is:
    - sharded over `device_mesh` (tensor parallelism), through `hf_quantizer.create_quantized_param` when the
      parameter needs quantization; or
    - offloaded with `offload_weight` when `device_map` sends it to "disk" (non-safetensors files only); or
    - assigned into the model on its `device_map` device (CPU without a device map), possibly via the quantizer.

    Returns:
        `Optional[dict]`: the (possibly updated) `disk_offload_index`.
    """
    # Device used when opening the safetensors file. It stays CPU unless `device_map` places the whole model ("" key)
    # on a single non-CPU device; then the device index (for `torch.device` values) or the raw value is used.
    tensor_device = "cpu"
    if device_map is not None and device_map.get("", None) is not None:
        if device_map[""] not in ("cpu", torch.device("cpu")):
            tensor_device = device_map[""].index if isinstance(device_map[""], torch.device) else device_map[""]
    if device_map is not None:
        # Alternation of all device_map keys. They are reverse-sorted so that a key is listed before any shorter key
        # that is its prefix (a regex alternation takes the first branch that matches).
        device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])

    is_quantized = hf_quantizer is not None
    is_safetensors = shard_file.endswith(".safetensors")
    is_meta_state_dict = is_safetensors
    # NOTE(review): this handle is only closed after the loop completes. An exception raised while loading leaves it
    # open; consider `try`/`finally` or a `with` block.
    file_pointer = safe_open(shard_file, framework="pt", device=tensor_device) if is_meta_state_dict else None
    # Snapshot the keys: entries may be deleted from `state_dict` inside the loop
    params_to_load = list(state_dict.keys())

    for param_name in params_to_load:
        empty_param = state_dict[param_name]
        # we need to use serialized_param_name as file pointer is untouched
        if is_meta_state_dict:
            # This is the name of the parameter as it appears on disk file
            serialized_param_name = reverse_renaming_mapping[param_name]
            param = file_pointer.get_slice(serialized_param_name)
        else:
            param = empty_param.to(tensor_device)  # It is actually not empty!

        to_contiguous, casting_dtype = _infer_parameter_dtype(
            model,
            param_name,
            empty_param,
            keep_in_fp32_regex,
            hf_quantizer,
        )

        if device_mesh is not None:
            if not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
                # In this case, the param is already on the correct device!
                shard_and_distribute_module(
                    model,
                    param,
                    empty_param,
                    param_name,
                    casting_dtype,
                    to_contiguous,
                    device_mesh.get_local_rank(),
                    device_mesh,
                )
            else:
                # we have a device mesh but the param needs to be quantized, so we shard inside create_quantized_param
                sharding_kwargs = {
                    "empty_param": empty_param,
                    "casting_dtype": casting_dtype,
                    "to_contiguous": to_contiguous,
                    "rank": device_mesh.get_local_rank(),
                    "device_mesh": device_mesh,
                }
                hf_quantizer.create_quantized_param(
                    model,
                    param,
                    param_name,
                    device_mesh.get_local_rank(),
                    **sharding_kwargs,
                )
        else:
            # Ellipsis indexing materializes the full tensor (reads it from a safetensors slice; a view otherwise)
            param = param[...]
            if casting_dtype is not None:
                param = param.to(casting_dtype)
            if to_contiguous:
                param = param.contiguous()

            if device_map is None:
                param_device = "cpu"
            else:
                module_layer = re.search(device_map_regex, param_name)
                if not module_layer:
                    raise ValueError(f"{param_name} doesn't have any device set.")
                else:
                    param_device = device_map[module_layer.group()]

            if param_device == "disk":
                # Safetensors shards are not re-offloaded here (presumably the index already points at the original
                # file -- confirm with the caller)
                if not is_safetensors:
                    disk_offload_index = offload_weight(param, param_name, disk_offload_folder, disk_offload_index)
            elif not is_quantized or not hf_quantizer.param_needs_quantization(model, param_name):
                if is_fsdp_enabled():
                    # Under FSDP only local rank 0 materializes weights; other ranks keep them on "meta"
                    param_device = "cpu" if is_local_dist_rank_0() else "meta"

                _load_parameter_into_model(model, param_name, param.to(param_device))

            else:
                # TODO naming is stupid it loads it as well
                hf_quantizer.create_quantized_param(model, param, param_name, param_device)

                # For quantized modules with FSDP/DeepSpeed Stage 3, we need to quantize the parameter on the GPU
                # and then cast it to CPU to avoid excessive memory usage on each GPU
                # in comparison to the sharded model across GPUs.
                if is_fsdp_enabled() or is_deepspeed_zero3_enabled():
                    # NOTE(review): `param_name` is rebound here, so the `del state_dict[param_name]` at the end of
                    # this iteration uses the quantizer's name -- confirm it is still a key of `state_dict`.
                    param_name = hf_quantizer.get_param_name(param_name)
                    module, param_type = get_module_from_name(model, param_name)
                    value = getattr(module, param_type)
                    # We need to wait until the quantized value is created
                    if value.device.type == "meta":
                        continue
                    val_kwargs = value.__dict__
                    if not value.is_floating_point():
                        val_kwargs["requires_grad"] = False
                    device = "meta" if is_fsdp_enabled() and not is_local_dist_rank_0() else "cpu"
                    # Rebuild the same parameter class around the data moved to `device`
                    value = type(value)(value.data.to(device), **val_kwargs)
                    setattr(module, param_type, value)

        # Remove the param from the state dict if it was not loaded on the fly to avoid wasting memory
        if not is_meta_state_dict:
            del state_dict[param_name]

    if file_pointer is not None:
        file_pointer.__exit__(None, None, None)

    return disk_offload_index
  703. def load_shard_file(args):
  704. (
  705. shard_file,
  706. state_dict,
  707. disk_only_shard_files,
  708. is_quantized,
  709. device_map,
  710. hf_quantizer,
  711. key_renaming_mapping,
  712. weights_only,
  713. model,
  714. reverse_key_renaming_mapping,
  715. disk_offload_folder,
  716. disk_offload_index,
  717. keep_in_fp32_regex,
  718. device_mesh,
  719. ) = args
  720. # Skip the load for shards that only contain disk-offloaded weights
  721. if shard_file in disk_only_shard_files:
  722. return [], disk_offload_index
  723. map_location = "cpu"
  724. if shard_file.endswith(".safetensors") and not (is_deepspeed_zero3_enabled() and not is_quantized):
  725. map_location = "meta"
  726. # If shard_file is "", we use the existing state_dict instead of loading it
  727. if shard_file != "":
  728. state_dict = load_state_dict(
  729. shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
  730. )
  731. # Fix the key names
  732. state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
  733. error_msgs = []
  734. if is_deepspeed_zero3_enabled() and not is_quantized:
  735. error_msgs += _load_state_dict_into_zero3_model(model, state_dict)
  736. # Skip it with fsdp on ranks other than 0
  737. elif not (is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized):
  738. disk_offload_index = _load_state_dict_into_meta_model(
  739. model,
  740. state_dict,
  741. shard_file,
  742. reverse_key_renaming_mapping,
  743. device_map=device_map,
  744. disk_offload_folder=disk_offload_folder,
  745. disk_offload_index=disk_offload_index,
  746. hf_quantizer=hf_quantizer,
  747. keep_in_fp32_regex=keep_in_fp32_regex,
  748. device_mesh=device_mesh,
  749. )
  750. return error_msgs, disk_offload_index
  751. def load_shard_files_with_threadpool(args_list):
  752. num_workers = int(os.environ.get("HF_PARALLEL_LOADING_WORKERS", "8"))
  753. # Do not spawn anymore workers than you need
  754. num_workers = min(len(args_list), num_workers)
  755. logger.info(f"Loading model weights in parallel with {num_workers} workers...")
  756. error_msgs = []
  757. with ThreadPoolExecutor(max_workers=num_workers) as executor:
  758. with logging.tqdm(total=len(args_list), desc="Loading checkpoint shards") as pbar:
  759. futures = [executor.submit(load_shard_file, arg) for arg in args_list]
  760. for future in as_completed(futures):
  761. _error_msgs, disk_offload_index = future.result()
  762. error_msgs += _error_msgs
  763. pbar.update(1)
  764. return error_msgs, disk_offload_index
  765. def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
  766. if variant is not None:
  767. path, name = weights_name.rsplit(".", 1)
  768. weights_name = f"{path}.{variant}.{name}"
  769. return weights_name
  770. def _get_resolved_checkpoint_files(
  771. pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
  772. subfolder: str,
  773. variant: Optional[str],
  774. gguf_file: Optional[str],
  775. from_tf: bool,
  776. from_flax: bool,
  777. use_safetensors: bool,
  778. cache_dir: str,
  779. force_download: bool,
  780. proxies: Optional[dict[str, str]],
  781. local_files_only: bool,
  782. token: Optional[Union[str, bool]],
  783. user_agent: dict,
  784. revision: str,
  785. commit_hash: Optional[str],
  786. is_remote_code: bool, # Because we can't determine this inside this function, we need it to be passed in
  787. transformers_explicit_filename: Optional[str] = None,
  788. ) -> tuple[Optional[list[str]], Optional[dict]]:
  789. """Get all the checkpoint filenames based on `pretrained_model_name_or_path`, and optional metadata if the
  790. checkpoints are sharded.
  791. This function will download the data if necessary.
  792. """
  793. is_sharded = False
  794. if pretrained_model_name_or_path is not None and gguf_file is None:
  795. pretrained_model_name_or_path = str(pretrained_model_name_or_path)
  796. is_local = os.path.isdir(pretrained_model_name_or_path)
  797. if is_local:
  798. if transformers_explicit_filename is not None:
  799. # If the filename is explicitly defined, load this by default.
  800. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, transformers_explicit_filename)
  801. is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
  802. elif from_tf and os.path.isfile(
  803. os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
  804. ):
  805. # Load from a TF 1.0 checkpoint in priority if from_tf
  806. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index")
  807. elif from_tf and os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)):
  808. # Load from a TF 2.0 checkpoint in priority if from_tf
  809. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME)
  810. elif from_flax and os.path.isfile(
  811. os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
  812. ):
  813. # Load from a Flax checkpoint in priority if from_flax
  814. archive_file = os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
  815. elif use_safetensors is not False and os.path.isfile(
  816. os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant))
  817. ):
  818. # Load from a safetensors checkpoint
  819. archive_file = os.path.join(
  820. pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_NAME, variant)
  821. )
  822. elif use_safetensors is not False and os.path.isfile(
  823. os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant))
  824. ):
  825. # Load from a sharded safetensors checkpoint
  826. archive_file = os.path.join(
  827. pretrained_model_name_or_path, subfolder, _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)
  828. )
  829. is_sharded = True
  830. elif not use_safetensors and os.path.isfile(
  831. os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant))
  832. ):
  833. # Load from a PyTorch checkpoint
  834. archive_file = os.path.join(
  835. pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_NAME, variant)
  836. )
  837. elif not use_safetensors and os.path.isfile(
  838. os.path.join(pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant))
  839. ):
  840. # Load from a sharded PyTorch checkpoint
  841. archive_file = os.path.join(
  842. pretrained_model_name_or_path, subfolder, _add_variant(WEIGHTS_INDEX_NAME, variant)
  843. )
  844. is_sharded = True
  845. # At this stage we don't have a weight file so we will raise an error.
  846. elif not use_safetensors and (
  847. os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF_WEIGHTS_NAME + ".index"))
  848. or os.path.isfile(os.path.join(pretrained_model_name_or_path, subfolder, TF2_WEIGHTS_NAME))
  849. ):
  850. raise OSError(
  851. f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
  852. f" {pretrained_model_name_or_path} but there is a file for TensorFlow weights. Use"
  853. " `from_tf=True` to load this model from those weights."
  854. )
  855. elif not use_safetensors and os.path.isfile(
  856. os.path.join(pretrained_model_name_or_path, subfolder, FLAX_WEIGHTS_NAME)
  857. ):
  858. raise OSError(
  859. f"Error no file named {_add_variant(WEIGHTS_NAME, variant)} found in directory"
  860. f" {pretrained_model_name_or_path} but there is a file for Flax weights. Use `from_flax=True`"
  861. " to load this model from those weights."
  862. )
  863. elif use_safetensors:
  864. raise OSError(
  865. f"Error no file named {_add_variant(SAFE_WEIGHTS_NAME, variant)} found in directory"
  866. f" {pretrained_model_name_or_path}."
  867. )
  868. else:
  869. raise OSError(
  870. f"Error no file named {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
  871. f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME + '.index'} or {FLAX_WEIGHTS_NAME} found in directory"
  872. f" {pretrained_model_name_or_path}."
  873. )
  874. elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path)):
  875. archive_file = pretrained_model_name_or_path
  876. is_local = True
  877. elif os.path.isfile(os.path.join(subfolder, pretrained_model_name_or_path + ".index")):
  878. if not from_tf:
  879. raise ValueError(
  880. f"We found a TensorFlow checkpoint at {pretrained_model_name_or_path + '.index'}, please set "
  881. "from_tf to True to load from this checkpoint."
  882. )
  883. archive_file = os.path.join(subfolder, pretrained_model_name_or_path + ".index")
  884. is_local = True
  885. elif is_remote_url(pretrained_model_name_or_path):
  886. filename = pretrained_model_name_or_path
  887. resolved_archive_file = download_url(pretrained_model_name_or_path)
  888. else:
  889. # set correct filename
  890. if transformers_explicit_filename is not None:
  891. filename = transformers_explicit_filename
  892. is_sharded = transformers_explicit_filename.endswith(".safetensors.index.json")
  893. elif from_tf:
  894. filename = TF2_WEIGHTS_NAME
  895. elif from_flax:
  896. filename = FLAX_WEIGHTS_NAME
  897. elif use_safetensors is not False:
  898. filename = _add_variant(SAFE_WEIGHTS_NAME, variant)
  899. else:
  900. filename = _add_variant(WEIGHTS_NAME, variant)
  901. try:
  902. # Load from URL or cache if already cached
  903. cached_file_kwargs = {
  904. "cache_dir": cache_dir,
  905. "force_download": force_download,
  906. "proxies": proxies,
  907. "local_files_only": local_files_only,
  908. "token": token,
  909. "user_agent": user_agent,
  910. "revision": revision,
  911. "subfolder": subfolder,
  912. "_raise_exceptions_for_gated_repo": False,
  913. "_raise_exceptions_for_missing_entries": False,
  914. "_commit_hash": commit_hash,
  915. }
  916. resolved_archive_file = cached_file(pretrained_model_name_or_path, filename, **cached_file_kwargs)
  917. # Since we set _raise_exceptions_for_missing_entries=False, we don't get an exception but a None
  918. # result when internet is up, the repo and revision exist, but the file does not.
  919. if resolved_archive_file is None and filename == _add_variant(SAFE_WEIGHTS_NAME, variant):
  920. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
  921. resolved_archive_file = cached_file(
  922. pretrained_model_name_or_path,
  923. _add_variant(SAFE_WEIGHTS_INDEX_NAME, variant),
  924. **cached_file_kwargs,
  925. )
  926. if resolved_archive_file is not None:
  927. is_sharded = True
  928. elif use_safetensors:
  929. if revision == "main":
  930. resolved_archive_file, revision, is_sharded = auto_conversion(
  931. pretrained_model_name_or_path, **cached_file_kwargs
  932. )
  933. cached_file_kwargs["revision"] = revision
  934. if resolved_archive_file is None:
  935. raise OSError(
  936. f"{pretrained_model_name_or_path} does not appear to have a file named"
  937. f" {_add_variant(SAFE_WEIGHTS_NAME, variant)} or {_add_variant(SAFE_WEIGHTS_INDEX_NAME, variant)} "
  938. "and thus cannot be loaded with `safetensors`. Please make sure that the model has "
  939. "been saved with `safe_serialization=True` or do not set `use_safetensors=True`."
  940. )
  941. else:
  942. # This repo has no safetensors file of any kind, we switch to PyTorch.
  943. filename = _add_variant(WEIGHTS_NAME, variant)
  944. resolved_archive_file = cached_file(
  945. pretrained_model_name_or_path, filename, **cached_file_kwargs
  946. )
  947. if resolved_archive_file is None and filename == _add_variant(WEIGHTS_NAME, variant):
  948. # Maybe the checkpoint is sharded, we try to grab the index name in this case.
  949. resolved_archive_file = cached_file(
  950. pretrained_model_name_or_path,
  951. _add_variant(WEIGHTS_INDEX_NAME, variant),
  952. **cached_file_kwargs,
  953. )
  954. if resolved_archive_file is not None:
  955. is_sharded = True
  956. if not local_files_only and not is_offline_mode():
  957. if resolved_archive_file is not None:
  958. # In a CI environment (CircleCI / Github Actions workflow runs) or in a pytest run,
  959. # we set `DISABLE_SAFETENSORS_CONVERSION=true` to prevent the conversion.
  960. if (
  961. filename in [WEIGHTS_NAME, WEIGHTS_INDEX_NAME]
  962. and os.getenv("DISABLE_SAFETENSORS_CONVERSION", None) != "true"
  963. ):
  964. # If the PyTorch file was found, check if there is a safetensors file on the repository
  965. # If there is no safetensors file on the repositories, start an auto conversion
  966. safe_weights_name = SAFE_WEIGHTS_INDEX_NAME if is_sharded else SAFE_WEIGHTS_NAME
  967. has_file_kwargs = {
  968. "revision": revision,
  969. "proxies": proxies,
  970. "token": token,
  971. "cache_dir": cache_dir,
  972. "local_files_only": local_files_only,
  973. }
  974. cached_file_kwargs = {
  975. "cache_dir": cache_dir,
  976. "force_download": force_download,
  977. "local_files_only": local_files_only,
  978. "user_agent": user_agent,
  979. "subfolder": subfolder,
  980. "_raise_exceptions_for_gated_repo": False,
  981. "_raise_exceptions_for_missing_entries": False,
  982. "_commit_hash": commit_hash,
  983. **has_file_kwargs,
  984. }
  985. if (
  986. not has_file(pretrained_model_name_or_path, safe_weights_name, **has_file_kwargs)
  987. and not is_remote_code
  988. ):
  989. Thread(
  990. target=auto_conversion,
  991. args=(pretrained_model_name_or_path,),
  992. kwargs={"ignore_errors_during_conversion": True, **cached_file_kwargs},
  993. name="Thread-auto_conversion",
  994. ).start()
  995. else:
  996. # Otherwise, no PyTorch file was found, maybe there is a TF or Flax model file.
  997. # We try those to give a helpful error message.
  998. has_file_kwargs = {
  999. "revision": revision,
  1000. "proxies": proxies,
  1001. "token": token,
  1002. "cache_dir": cache_dir,
  1003. "local_files_only": local_files_only,
  1004. }
  1005. if has_file(pretrained_model_name_or_path, TF2_WEIGHTS_NAME, **has_file_kwargs):
  1006. raise OSError(
  1007. f"{pretrained_model_name_or_path} does not appear to have a file named"
  1008. f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for TensorFlow weights."
  1009. " Use `from_tf=True` to load this model from those weights."
  1010. )
  1011. elif has_file(pretrained_model_name_or_path, FLAX_WEIGHTS_NAME, **has_file_kwargs):
  1012. raise OSError(
  1013. f"{pretrained_model_name_or_path} does not appear to have a file named"
  1014. f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file for Flax weights. Use"
  1015. " `from_flax=True` to load this model from those weights."
  1016. )
  1017. elif variant is not None and has_file(
  1018. pretrained_model_name_or_path, WEIGHTS_NAME, **has_file_kwargs
  1019. ):
  1020. raise OSError(
  1021. f"{pretrained_model_name_or_path} does not appear to have a file named"
  1022. f" {_add_variant(WEIGHTS_NAME, variant)} but there is a file without the variant"
  1023. f" {variant}. Use `variant=None` to load this model from those weights."
  1024. )
  1025. else:
  1026. raise OSError(
  1027. f"{pretrained_model_name_or_path} does not appear to have a file named"
  1028. f" {_add_variant(WEIGHTS_NAME, variant)}, {_add_variant(SAFE_WEIGHTS_NAME, variant)},"
  1029. f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
  1030. )
  1031. except OSError:
  1032. # Raise any environment error raise by `cached_file`. It will have a helpful error message adapted
  1033. # to the original exception.
  1034. raise
  1035. except Exception as e:
  1036. # For any other exception, we throw a generic error.
  1037. raise OSError(
  1038. f"Can't load the model for '{pretrained_model_name_or_path}'. If you were trying to load it"
  1039. " from 'https://huggingface.co/models', make sure you don't have a local directory with the"
  1040. f" same name. Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a"
  1041. f" directory containing a file named {_add_variant(WEIGHTS_NAME, variant)},"
  1042. f" {TF2_WEIGHTS_NAME}, {TF_WEIGHTS_NAME} or {FLAX_WEIGHTS_NAME}."
  1043. ) from e
  1044. if is_local:
  1045. logger.info(f"loading weights file {archive_file}")
  1046. resolved_archive_file = archive_file
  1047. else:
  1048. logger.info(f"loading weights file {filename} from cache at {resolved_archive_file}")
  1049. elif gguf_file:
  1050. # Case 1: the GGUF file is present locally
  1051. if os.path.isfile(gguf_file):
  1052. resolved_archive_file = gguf_file
  1053. # Case 2: The GGUF path is a location on the Hub
  1054. # Load from URL or cache if already cached
  1055. else:
  1056. cached_file_kwargs = {
  1057. "cache_dir": cache_dir,
  1058. "force_download": force_download,
  1059. "proxies": proxies,
  1060. "local_files_only": local_files_only,
  1061. "token": token,
  1062. "user_agent": user_agent,
  1063. "revision": revision,
  1064. "subfolder": subfolder,
  1065. "_raise_exceptions_for_gated_repo": False,
  1066. "_raise_exceptions_for_missing_entries": False,
  1067. "_commit_hash": commit_hash,
  1068. }
  1069. resolved_archive_file = cached_file(pretrained_model_name_or_path, gguf_file, **cached_file_kwargs)
  1070. # We now download and resolve all checkpoint files if the checkpoint is sharded
  1071. sharded_metadata = None
  1072. if is_sharded:
  1073. checkpoint_files, sharded_metadata = get_checkpoint_shard_files(
  1074. pretrained_model_name_or_path,
  1075. resolved_archive_file,
  1076. cache_dir=cache_dir,
  1077. force_download=force_download,
  1078. proxies=proxies,
  1079. local_files_only=local_files_only,
  1080. token=token,
  1081. user_agent=user_agent,
  1082. revision=revision,
  1083. subfolder=subfolder,
  1084. _commit_hash=commit_hash,
  1085. )
  1086. else:
  1087. checkpoint_files = [resolved_archive_file] if pretrained_model_name_or_path is not None else None
  1088. return checkpoint_files, sharded_metadata
  def _get_dtype(
      cls,
      dtype: Optional[Union[str, torch.dtype, dict]],
      checkpoint_files: Optional[list[str]],
      config: PretrainedConfig,
      sharded_metadata: Optional[dict],
      state_dict: Optional[dict],
      weights_only: bool,
  ) -> tuple[PretrainedConfig, Optional[torch.dtype], Optional[torch.dtype]]:
      """Find the correct `dtype` to use based on the provided arguments, and update `config` with it.

      Resolution rules:
      1. If `dtype` is None, `config` and each of its sub-configs are stamped with `torch.get_default_dtype()`, and
         `dtype` is returned unchanged (None).
      2. If `dtype` is "auto", use `config.dtype` when it is set. Otherwise use the `"dtype"` entry of the sharded
         metadata, or detect it from `state_dict`, or from the first checkpoint file loaded with
         `map_location="meta"` (the first floating weight decides).
      3. If `dtype` is a `torch.dtype` or the name of one (e.g. `"float16"`), it is written to `config` and to every
         sub-config.
      4. If `dtype` is a dict, each key naming an attribute of `config` (a sub-config) sets that sub-config's dtype;
         the `""` key, if present, is the main dtype for modules outside any sub-config. When it is absent,
         `config.dtype` is set to None and `torch.float32` is used.

      Args:
          cls: The model class; its `_set_default_dtype` is called with the resolved dtype.
          dtype: The requested dtype (see rules above).
          checkpoint_files: Checkpoint paths; only the first one is read, and only for `"auto"` detection.
          config: The model config, updated in place and returned.
          sharded_metadata: Metadata of a sharded checkpoint, or None when the checkpoint is not sharded.
          state_dict: An already-loaded state dict, used for `"auto"` detection when available.
          weights_only: Forwarded to `load_state_dict`.

      Returns:
          `(config, dtype, dtype_orig)`, where `dtype_orig` is the value returned by `cls._set_default_dtype(dtype)`
          (presumably the previous default dtype, to be restored later — TODO confirm), or None when `dtype` was None.

      Raises:
          ValueError: If `dtype` has an unsupported type, or is a string that is neither `"auto"` nor the name of a
              `torch.dtype`.
      """
      dtype_orig = None
      is_sharded = sharded_metadata is not None

      if dtype is not None:
          if isinstance(dtype, str):
              if dtype == "auto":
                  if hasattr(config, "dtype") and config.dtype is not None:
                      dtype = config.dtype
                      logger.info(f"Will use dtype={dtype} as defined in model's config object")
                  else:
                      if is_sharded and "dtype" in sharded_metadata:
                          dtype = sharded_metadata["dtype"]
                      elif state_dict is not None:
                          dtype = get_state_dict_dtype(state_dict)
                      else:
                          # Only tensor metadata is needed here, so the weights are loaded on the meta device.
                          state_dict = load_state_dict(
                              checkpoint_files[0], map_location="meta", weights_only=weights_only
                          )
                          dtype = get_state_dict_dtype(state_dict)
                      # Both parts must be f-strings, otherwise the literal "{dtype}" is logged.
                      logger.info(
                          "Since the `dtype` attribute can't be found in model's config object, "
                          f"will use dtype={dtype} as derived from model's weights"
                      )
              else:
                  # `hasattr(torch, dtype)` alone would also accept non-dtype attributes such as "nn", and unknown
                  # names would be passed along as raw strings; require an actual `torch.dtype`.
                  resolved_dtype = getattr(torch, dtype, None)
                  if not isinstance(resolved_dtype, torch.dtype):
                      raise ValueError(
                          f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
                          f"for each sub-config in composite configs, but received {dtype}"
                      )
                  dtype = resolved_dtype
                  config.dtype = dtype
                  for sub_config_key in config.sub_configs:
                      sub_config = getattr(config, sub_config_key)
                      sub_config.dtype = dtype
          elif isinstance(dtype, torch.dtype):
              config.dtype = dtype
              for sub_config_key in config.sub_configs:
                  sub_config = getattr(config, sub_config_key)
                  sub_config.dtype = dtype
          elif isinstance(dtype, dict):
              for key, curr_dtype in dtype.items():
                  if hasattr(config, key):
                      value = getattr(config, key)
                      curr_dtype = curr_dtype if not isinstance(curr_dtype, str) else getattr(torch, curr_dtype)
                      value.dtype = curr_dtype
              # main torch dtype for modules that aren't part of any sub-config
              dtype = dtype.get("")
              dtype = dtype if not isinstance(dtype, str) else getattr(torch, dtype)
              config.dtype = dtype
              if dtype is None:
                  dtype = torch.float32
          else:
              raise ValueError(
                  f"`dtype` can be one of: `torch.dtype`, `'auto'`, a string of a valid `torch.dtype` or a `dict` with valid `dtype` "
                  f"for each sub-config in composite configs, but received {dtype}"
              )

          dtype_orig = cls._set_default_dtype(dtype)
      else:
          # No dtype requested: record torch's current default dtype (fp32 unless changed) for backward compatibility.
          default_dtype = torch.get_default_dtype()
          config.dtype = default_dtype
          for key in config.sub_configs:
              value = getattr(config, key)
              value.dtype = default_dtype

      return config, dtype, dtype_orig
  def _get_device_map(
      model: "PreTrainedModel",
      device_map: Optional[Union[dict, str]],
      max_memory: Optional[dict],
      hf_quantizer: Optional[HfQuantizer],
      dtype: Optional[torch.dtype],
      keep_in_fp32_regex: Optional[re.Pattern],
  ) -> dict:
      """Resolve `device_map` into its final form.

      A string strategy (e.g. 'auto', 'balanced', 'balanced_low_0', 'sequential') is expanded into an explicit mapping
      through `infer_auto_device_map`. The expansion accounts for:
      - weights needing a special dtype (quantizer-specific ones and those matched by `keep_in_fp32_regex`);
      - the quantizer's dtype and memory adjustments;
      - reserved-but-unused accelerator memory, capped by any user-provided `max_memory`.

      An explicit mapping is only checked for tied parameters placed on different devices, then returned as-is.
      `None` is returned unchanged.
      """
      if device_map is None:
          return device_map

      if not isinstance(device_map, str):
          # Explicit mapping from the user: tied weights must end up on the same device.
          check_tied_parameters_on_same_device(find_tied_parameters(model), device_map)
          return device_map

      # Weights that must not be materialized in `dtype`.
      special_dtypes = {}
      if hf_quantizer is not None:
          special_dtypes.update(hf_quantizer.get_special_dtypes_update(model, dtype))
      if keep_in_fp32_regex is not None:
          special_dtypes.update(
              {name: torch.float32 for name, _ in model.named_parameters() if keep_in_fp32_regex.search(name)}
          )

      target_dtype = dtype if hf_quantizer is None else hf_quantizer.adjust_target_dtype(dtype)

      device_map_kwargs = {"no_split_module_classes": model._get_no_split_modules(device_map)}
      # Only pass `special_dtypes` when the installed `infer_auto_device_map` accepts it.
      if "special_dtypes" in inspect.signature(infer_auto_device_map).parameters:
          device_map_kwargs["special_dtypes"] = special_dtypes
      elif special_dtypes:
          logger.warning(
              "This model has some weights that should be kept in higher precision, you need to upgrade "
              "`accelerate` to properly deal with them (`pip install --upgrade accelerate`)."
          )

      if device_map == "sequential":
          inferred_max_memory = get_max_memory(max_memory)
      else:
          inferred_max_memory = get_balanced_memory(
              model,
              dtype=target_dtype,
              low_zero=(device_map == "balanced_low_0"),
              max_memory=max_memory,
              **device_map_kwargs,
          )
      if hf_quantizer is not None:
          inferred_max_memory = hf_quantizer.adjust_max_memory(inferred_max_memory)

      # `inferred_max_memory` only counts non-reserved memory; reserved-but-unused accelerator memory can also hold
      # parameters.
      for device_name in inferred_max_memory:
          if isinstance(device_name, int):  # integer keys are accelerator indices
              accelerator = torch.xpu if is_torch_xpu_available() else torch.cuda
              unused = accelerator.memory_reserved(device_name) - accelerator.memory_allocated(device_name)
              inferred_max_memory[device_name] += unused
          # Never exceed a budget the user asked for explicitly.
          if max_memory is not None and device_name in max_memory:
              inferred_max_memory[device_name] = min(inferred_max_memory[device_name], max_memory[device_name])

      device_map_kwargs["max_memory"] = inferred_max_memory
      device_map = infer_auto_device_map(model, dtype=target_dtype, **device_map_kwargs)
      if hf_quantizer is not None:
          hf_quantizer.validate_environment(device_map=device_map)
      return device_map
  def _find_missing_and_unexpected_keys(
      model: "PreTrainedModel",
      original_checkpoint_keys: list[str],
      checkpoint_keys: list[str],
      loading_base_model_from_task_state_dict: bool,
      hf_quantizer: Optional[HfQuantizer],
  ) -> tuple[list[str], list[str]]:
      """Compare the keys the model expects with the keys found in the checkpoint.

      Returns `(missing_keys, unexpected_keys)`, sorted before any quantizer adjustment:
      - Missing keys are model state-dict keys absent from `checkpoint_keys`. Keys of a tied-parameter group are
        dropped when only part of that group is missing.
      - Unexpected keys are checkpoint keys the model does not expect, minus the model's buffer names. When loading a
        base model from a task checkpoint, this also includes every original key outside `base_model_prefix`.

      The quantizer, if any, may rewrite the expected, missing and unexpected keys.
      """
      prefix = model.base_model_prefix

      # Keys the full model expects, possibly rewritten by the quantizer.
      expected_keys = list(model.state_dict().keys())
      if hf_quantizer is not None:
          expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)

      expected_set = set(expected_keys)
      loaded_set = set(checkpoint_keys)
      missing_keys = sorted(expected_set - loaded_set)
      unexpected_set = loaded_set - expected_set

      # A module with the same name under the base and the task-specific model must still be reported as unexpected.
      if loading_base_model_from_task_state_dict:
          base_prefix = f"{prefix}."
          unexpected_set.update(key for key in original_checkpoint_keys if not key.startswith(base_prefix))

      # Non-persistent buffers are absent from the model state dict but may be in the checkpoint: not unexpected.
      buffer_names = {name for name, _ in model.named_buffers()}
      unexpected_keys = sorted(unexpected_set - buffer_names)

      # When only part of a tied group is missing, drop those keys (presumably filled through the tie).
      for group in find_tied_parameters(model):
          missing_in_group = [key for key in missing_keys if key in group]
          if 0 < len(missing_in_group) < len(group):
              missing_keys = [key for key in missing_keys if key not in missing_in_group]

      if hf_quantizer is not None:
          missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix)
          unexpected_keys = hf_quantizer.update_unexpected_keys(model, unexpected_keys)
      return missing_keys, unexpected_keys
  def _find_mismatched_keys(
      model: "PreTrainedModel",
      state_dict: Optional[dict],
      checkpoint_files: Optional[list[str]],
      ignore_mismatched_sizes: bool,
      keys_to_rename_mapping: dict[str, str],
      is_quantized: bool,
      weights_only: bool,
  ) -> tuple[list[str], list[tuple[int, int]]]:
      """
      List checkpoint weights whose shape differs from the corresponding model parameter.

      The scan only runs when `ignore_mismatched_sizes` is True. Otherwise `([], [])` is returned immediately and any
      mismatch surfaces as an error later, which avoids inspecting every weight up front (mismatches are rare).

      When mismatches are to be ignored, they must be known in advance. Mismatched weights have to be moved back from
      meta to cpu and initialized, and initialization runs per module rather than per weight. Doing this
      shard-by-shard at loading time could therefore re-initialize weights already loaded from an earlier shard.

      If `state_dict` is given, it is the only source inspected. Otherwise every file in `checkpoint_files` is loaded
      with `map_location="meta"`. Checkpoint keys are renamed through `keys_to_rename_mapping`, and keys without an
      entry there are skipped.

      For quantized models, a tensor whose last dimension is 1 and which holds exactly half as many elements as the
      model parameter is treated as packed 4-bit data rather than as a mismatch.

      Returns:
          `(mismatched_keys, mismatched_shapes)`: the renamed keys, and for each one a
          `(checkpoint_shape, model_shape)` pair of `torch.Size` objects.
      """
      if not ignore_mismatched_sizes:
          return [], []

      # "" stands for the in-memory `state_dict`, which takes precedence over any checkpoint file.
      shards = [""] if state_dict is not None else checkpoint_files
      model_state_dict = model.state_dict()
      mismatched_keys = []
      mismatched_shapes = []
      for shard_file in shards:
          if shard_file != "":
              state_dict = load_state_dict(
                  shard_file, is_quantized=is_quantized, map_location="meta", weights_only=weights_only
              )
          renamed = {keys_to_rename_mapping[k]: v for k, v in state_dict.items() if k in keys_to_rename_mapping}
          for key, tensor in renamed.items():
              if key not in model_state_dict:
                  continue
              expected_shape = model_state_dict[key].shape
              if tensor.shape == expected_shape:
                  continue
              # Two 4-bit values share one 8-bit container, so a packed weight holds half the elements.
              looks_packed_4bit = (
                  is_quantized and tensor.shape[-1] == 1 and tensor.numel() * 2 == model_state_dict[key].numel()
              )
              if not looks_packed_4bit:
                  mismatched_keys.append(key)
                  mismatched_shapes.append((tensor.shape, expected_shape))
      return mismatched_keys, mismatched_shapes
  class PipelineParallel(Enum):
      """Two-member enumeration: `inputs` (0) and `outputs` (1).

      NOTE(review): presumably selects whether a pipeline-parallel plan applies to a module's inputs or its outputs —
      confirm against its usages.
      """

      inputs = 0
      outputs = 1
  1313. class ModuleUtilsMixin:
  1314. """
  1315. A few utilities for `torch.nn.Modules`, to be used as a mixin.
  1316. """
  @staticmethod
  def _hook_rss_memory_pre_forward(module, *args, **kwargs):
      """
      Forward pre-hook that stores the process's resident set size (as reported by `psutil`) in
      `module.mem_rss_pre_forward`.

      Raises:
          ImportError: If `psutil` is not installed. The original import error is chained as the cause.
      """
      try:
          import psutil
      except ImportError as err:
          # Chain explicitly so the root cause is reported as the direct cause, not as an error during handling.
          raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") from err

      module.mem_rss_pre_forward = psutil.Process(os.getpid()).memory_info().rss
      return None
  @staticmethod
  def _hook_rss_memory_post_forward(module, *args, **kwargs):
      """
      Forward hook that records the process's resident set size in `module.mem_rss_post_forward` and adds the
      increase since `module.mem_rss_pre_forward` to `module.mem_rss_diff` (starting from 0 if unset).

      Raises:
          ImportError: If `psutil` is not installed. The original import error is chained as the cause.
      """
      try:
          import psutil
      except ImportError as err:
          # Chain explicitly so the root cause is reported as the direct cause, not as an error during handling.
          raise ImportError("You need to install psutil (pip install psutil) to use memory tracing.") from err

      module.mem_rss_post_forward = psutil.Process(os.getpid()).memory_info().rss
      mem_rss_diff = module.mem_rss_post_forward - module.mem_rss_pre_forward
      # Accumulate across forward calls.
      module.mem_rss_diff = mem_rss_diff + getattr(module, "mem_rss_diff", 0)
      return None
  def add_memory_hooks(self):
      """
      Attach memory-tracing hooks to every sub-module returned by `self.modules()`.

      A forward pre-hook records the process memory before each forward pass. A forward hook records it afterwards
      and adds the difference to that sub-module's `mem_rss_diff` attribute.

      All counters are zeroed once the hooks are registered; call `model.reset_memory_hooks_state()` to zero them
      again later.
      """
      for submodule in self.modules():
          submodule.register_forward_pre_hook(self._hook_rss_memory_pre_forward)
          submodule.register_forward_hook(self._hook_rss_memory_post_forward)
      # Start every counter from zero.
      self.reset_memory_hooks_state()
  def reset_memory_hooks_state(self):
      """
      Set the memory-tracing counters `mem_rss_diff`, `mem_rss_post_forward` and `mem_rss_pre_forward` to 0 on every
      sub-module (see [`~modeling_utils.ModuleUtilsMixin.add_memory_hooks`]).
      """
      counter_names = ("mem_rss_diff", "mem_rss_post_forward", "mem_rss_pre_forward")
      for submodule in self.modules():
          for counter_name in counter_names:
              setattr(submodule, counter_name, 0)
  @property
  def device(self) -> torch.device:
      """
      `torch.device`: Device of this module's parameters, as reported by `get_parameter_device` (assuming every
      parameter lives on the same device).
      """
      parameter_device = get_parameter_device(self)
      return parameter_device
  @property
  def dtype(self) -> torch.dtype:
      """
      `torch.dtype`: Dtype of this module's parameters, as reported by `get_parameter_dtype` (assuming every
      parameter shares one dtype).
      """
      parameter_dtype = get_parameter_dtype(self)
      return parameter_dtype
  def invert_attention_mask(self, encoder_attention_mask: Tensor) -> Tensor:
      """
      Invert an attention mask (e.g., switches 0. and 1.) and turn it into an additive mask.

      Positions equal to 1 in the input become 0.0, and positions equal to 0 become the most negative finite value of
      `self.dtype`, so the result can be added to raw attention scores.

      Args:
          encoder_attention_mask (`torch.Tensor`):
              An attention mask of rank 2 (`[batch_size, seq_length]`) or rank 3
              (`[batch_size, from_seq_length, to_seq_length]`).

      Returns:
          `torch.Tensor`: The inverted rank-4 mask (a singleton dimension inserted at position 1), in `self.dtype`.

      Raises:
          ValueError: If `encoder_attention_mask` is neither 2D nor 3D. The previous code hit an UnboundLocalError
              here.
      """
      rank = encoder_attention_mask.dim()
      if rank == 3:
          encoder_extended_attention_mask = encoder_attention_mask[:, None, :, :]
      elif rank == 2:
          encoder_extended_attention_mask = encoder_attention_mask[:, None, None, :]
      else:
          raise ValueError(
              f"Wrong shape for encoder_attention_mask (shape {encoder_attention_mask.shape}): expected a 2D or 3D tensor"
          )
      # T5 has a mask that can compare sequence ids; that could be simulated by comparing the mask with its transpose,
      # cf. https://github.com/tensorflow/mesh/blob/8d2465e9bc93129b913b5ccc6a59aa97abd96ec6/mesh_tensorflow/transformer/transformer_layers.py#L270
      encoder_extended_attention_mask = encoder_extended_attention_mask.to(dtype=self.dtype)  # fp16 compatibility
      return (1.0 - encoder_extended_attention_mask) * torch.finfo(self.dtype).min
  @staticmethod
  def create_extended_attention_mask_for_decoder(input_shape, attention_mask, device=None):
      """
      Combine a causal (lower-triangular) mask with a padding mask for decoder self-attention.

      Args:
          input_shape (`tuple[int, int]`): `(batch_size, seq_length)` of the current input.
          attention_mask (`torch.Tensor`):
              Padding mask whose second dimension may exceed `seq_length`. The extra leading positions (e.g. when
              past key/values are used) are kept fully visible.
          device (`torch.device`, *optional*): Deprecated; defaults to `attention_mask.device`.

      Returns:
          `torch.Tensor` of shape `[batch_size, 1, seq_length, attention_mask.shape[1]]` in `attention_mask.dtype`.
      """
      if device is None:
          device = attention_mask.device
      else:
          warnings.warn(
              "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
          )

      batch_size, seq_length = input_shape
      positions = torch.arange(seq_length, device=device)
      # causal[b, i, j] is True when key position j does not come after query position i.
      causal = positions[None, None, :].repeat(batch_size, seq_length, 1) <= positions[None, :, None]
      causal = causal.to(attention_mask.dtype)

      prefix_seq_len = attention_mask.shape[1] - causal.shape[1]
      if prefix_seq_len > 0:
          prefix = torch.ones((batch_size, seq_length, prefix_seq_len), device=device, dtype=causal.dtype)
          causal = torch.cat([prefix, causal], dim=-1)

      return causal[:, None, :, :] * attention_mask[:, None, None, :]
  def get_extended_attention_mask(
      self,
      attention_mask: Tensor,
      input_shape: tuple[int, ...],
      device: Optional[torch.device] = None,
      dtype: Optional[torch.dtype] = None,
  ) -> Tensor:
      """
      Makes broadcastable attention and causal masks so that future and masked tokens are ignored.

      Arguments:
          attention_mask (`torch.Tensor`):
              Mask with ones indicating tokens to attend to, zeros for tokens to ignore. Either a padding mask
              (`[batch_size, seq_length]`) or a self-attention mask (`[batch_size, from_seq_length, to_seq_length]`).
          input_shape (`tuple[int]`):
              The shape of the input to the model.
          device (`torch.device`, *optional*):
              Deprecated (emits a `FutureWarning`); only forwarded to the decoder causal-mask helper.
          dtype (`torch.dtype`, *optional*):
              Dtype of the returned mask; defaults to `self.dtype`.

      Returns:
          `torch.Tensor`: A rank-4 additive mask in `dtype`. It is 0.0 where attention is allowed and
          `torch.finfo(dtype).min` where it is masked. For a 2D mask on a decoder, the causal mask is applied too.

      Raises:
          ValueError: If `attention_mask` is neither 2D nor 3D.
      """
      if dtype is None:
          dtype = self.dtype

      rank = attention_mask.dim()
      # The decoder helper emits its own deprecation warning, so only warn when it will not be called.
      if not (rank == 2 and self.config.is_decoder) and device is not None:
          warnings.warn(
              "The `device` argument is deprecated and will be removed in v5 of Transformers.", FutureWarning
          )

      if rank == 3:
          # Already a [batch, from, to] mask: just make it broadcastable over heads.
          extended_attention_mask = attention_mask[:, None, :, :]
      elif rank == 2:
          if self.config.is_decoder:
              extended_attention_mask = ModuleUtilsMixin.create_extended_attention_mask_for_decoder(
                  input_shape, attention_mask, device
              )
          else:
              extended_attention_mask = attention_mask[:, None, None, :]
      else:
          raise ValueError(
              f"Wrong shape for input_ids (shape {input_shape}) or attention_mask (shape {attention_mask.shape})"
          )

      # 1.0 (attend) -> 0.0 and 0.0 (masked) -> most negative value; added to raw scores before the softmax.
      extended_attention_mask = extended_attention_mask.to(dtype=dtype)  # fp16 compatibility
      return (1.0 - extended_attention_mask) * torch.finfo(dtype).min
  def get_head_mask(
      self, head_mask: Optional[Tensor], num_hidden_layers: int, is_attention_chunked: bool = False
  ) -> Tensor:
      """
      Prepare the head mask if needed.

      Args:
          head_mask (`torch.Tensor` with shape `[num_heads]` or `[num_hidden_layers x num_heads]`, *optional*):
              The mask indicating if we should keep the heads or not (1.0 for keep, 0.0 for discard).
          num_hidden_layers (`int`):
              The number of hidden layers in the model.
          is_attention_chunked (`bool`, *optional*, defaults to `False`):
              Whether or not the attentions scores are computed by chunks or not. Only the value `True` itself adds
              a trailing singleton dimension.

      Returns:
          `torch.Tensor` with shape `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` (from
          `self._convert_head_mask_to_5d`), or a list with `None` for each layer when `head_mask` is None.
      """
      if head_mask is None:
          # No masking requested: one placeholder per layer.
          return [None] * num_hidden_layers
      head_mask = self._convert_head_mask_to_5d(head_mask, num_hidden_layers)
      return head_mask.unsqueeze(-1) if is_attention_chunked is True else head_mask
  def _convert_head_mask_to_5d(self, head_mask, num_hidden_layers):
      """
      Reshape `head_mask` to `[num_hidden_layers x batch x num_heads x seq_length x seq_length]` (with size-1 batch
      and sequence dimensions) and cast it to `self.dtype`.

      A 1D mask (`[num_heads]`) is shared by every layer. A 2D mask (`[num_hidden_layers x num_heads]`) gives one mask
      per layer. A mask that is already 5D is only cast.

      Raises:
          ValueError: If the resulting mask is not 5D. An explicit exception is used because, unlike `assert`, it is
              not stripped when Python runs with `-O`.
      """
      if head_mask.dim() == 1:
          head_mask = head_mask[None, None, :, None, None].expand(num_hidden_layers, -1, -1, -1, -1)
      elif head_mask.dim() == 2:
          head_mask = head_mask[:, None, :, None, None]  # We can specify head_mask for each layer
      if head_mask.dim() != 5:
          raise ValueError(f"head_mask.dim != 5, instead {head_mask.dim()}")
      return head_mask.to(dtype=self.dtype)  # switch to float if need + fp16 compatibility
  def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
      """
      Get number of (optionally, trainable or non-embeddings) parameters in the module.

      Args:
          only_trainable (`bool`, *optional*, defaults to `False`):
              Whether or not to return only the number of trainable parameters
          exclude_embeddings (`bool`, *optional*, defaults to `False`):
              Whether or not to return only the number of non-embeddings parameters. The `weight` of every
              `nn.Embedding` sub-module is skipped.

      Returns:
          `int`: The number of parameters.

      Raises:
          ValueError: If the model is flagged `is_loaded_in_4bit` but `bitsandbytes` is not available.
      """
      if exclude_embeddings:
          # A set keeps the per-parameter membership test O(1) instead of scanning a list.
          embedding_param_names = {
              f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
          }
          total_parameters = [
              parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
          ]
      else:
          total_parameters = list(self.parameters())

      is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
      if is_loaded_in_4bit:
          if is_bitsandbytes_available():
              import bitsandbytes as bnb
          else:
              raise ValueError(
                  "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
                  " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
              )

      total_numel = 0
      for param in total_parameters:
          if param.requires_grad or not only_trainable:
              # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
              # used for the 4bit quantization (uint8 tensors are stored)
              if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
                  if hasattr(param, "element_size"):
                      num_bytes = param.element_size()
                  elif hasattr(param, "quant_storage"):
                      num_bytes = param.quant_storage.itemsize
                  else:
                      num_bytes = 1
                  total_numel += param.numel() * 2 * num_bytes
              else:
                  total_numel += param.numel()
      return total_numel
  1543. def estimate_tokens(self, input_dict: dict[str, Union[torch.Tensor, Any]]) -> int:
  1544. """
  1545. Helper function to estimate the total number of tokens from the model inputs.
  1546. Args:
  1547. inputs (`dict`): The model inputs.
  1548. Returns:
  1549. `int`: The total number of tokens.
  1550. """
  1551. if not hasattr(self, "warnings_issued"):
  1552. self.warnings_issued = {}
  1553. if self.main_input_name in input_dict:
  1554. return input_dict[self.main_input_name].numel()
  1555. elif "estimate_tokens" not in self.warnings_issued:
  1556. logger.warning(
  1557. "Could not estimate the number of tokens of the input, floating-point operations will not be computed"
  1558. )
  1559. self.warnings_issued["estimate_tokens"] = True
  1560. return 0
  1561. def floating_point_ops(
  1562. self, input_dict: dict[str, Union[torch.Tensor, Any]], exclude_embeddings: bool = True
  1563. ) -> int:
  1564. """
  1565. Get number of (optionally, non-embeddings) floating-point operations for the forward and backward passes of a
  1566. batch with this transformer model. Default approximation neglects the quadratic dependency on the number of
  1567. tokens (valid if `12 * d_model << sequence_length`) as laid out in [this
  1568. paper](https://huggingface.co/papers/2001.08361) section 2.1. Should be overridden for transformers with parameter
  1569. re-use e.g. Albert or Universal Transformers, or if doing long-range modeling with very high sequence lengths.
  1570. Args:
  1571. batch_size (`int`):
  1572. The batch size for the forward pass.
  1573. sequence_length (`int`):
  1574. The number of tokens in each line of the batch.
  1575. exclude_embeddings (`bool`, *optional*, defaults to `True`):
  1576. Whether or not to count embedding and softmax operations.
  1577. Returns:
  1578. `int`: The number of floating-point operations.
  1579. """
  1580. return 6 * self.estimate_tokens(input_dict) * self.num_parameters(exclude_embeddings=exclude_embeddings)
  1581. class EmbeddingAccessMixin:
  1582. """
  1583. Base utilities to regroup getters and setters for embeddings.
  1584. Introduces the `input_layer_embed` attribute, which indicates
  1585. where the input embeddings come from and where they
  1586. should be set.
  1587. """
  1588. _input_embed_layer = "embed_tokens" # default layer that holds input embeddings.
  1589. def get_input_embeddings(self) -> nn.Module:
  1590. """
  1591. Returns the model's input embeddings.
  1592. Returns:
  1593. `nn.Module`: A torch module mapping vocabulary to hidden states.
  1594. """
  1595. # 1) Check if the model has an attribute named 'embed_tokens' (the standard input embedding layer
  1596. # for most NLP models), and if so, return it.
  1597. name = getattr(self, "_input_embed_layer", "embed_tokens")
  1598. if (default_embedding := getattr(self, name, None)) is not None:
  1599. return default_embedding
  1600. # 2) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
  1601. if hasattr(self, "model") and hasattr(self.model, "embed_tokens"):
  1602. return self.model.embed_tokens
  1603. # 3) vanilla decoder‑only architectures
  1604. elif hasattr(self, "embed_tokens"):
  1605. return self.embed_tokens
  1606. else:
  1607. base_model = getattr(self, "base_model_prefix", None)
  1608. if base_model is not None:
  1609. base_model = getattr(self, base_model, None)
  1610. if base_model is not None and base_model is not self:
  1611. return base_model.get_input_embeddings()
  1612. raise NotImplementedError(
  1613. f"`get_input_embeddings` not auto‑handled for {self.__class__.__name__}; "
  1614. "please override in the subclass."
  1615. )
  1616. def set_input_embeddings(self, value: nn.Module):
  1617. """Fallback setter that handles **~70%** of models in the code-base.
  1618. Order of attempts:
  1619. 1. `self.model.embed_tokens`
  1620. 2. `self.embed_tokens`
  1621. 3. delegate to the *base model* if one exists
  1622. 4. otherwise raise `NotImplementedError` so subclasses still can (and
  1623. should) override for exotic layouts.
  1624. """
  1625. # 1) encoder/decoder and VLMs like `Gemma3nForConditionalGeneration`
  1626. name = getattr(self, "_input_embed_layer", "embed_tokens")
  1627. if hasattr(self, "model") and hasattr(self.model, name):
  1628. setattr(self.model, name, value)
  1629. # 2) as well as vanilla decoder‑only architectures
  1630. elif hasattr(self, name):
  1631. setattr(self, name, value)
  1632. # 3) recurse once into the registered *base* model (e.g. for encoder/decoder)
  1633. elif getattr(self, self.base_model_prefix, self) is not self:
  1634. base_model = getattr(self, self.base_model_prefix, self)
  1635. base_model.set_input_embeddings(value)
  1636. else:
  1637. raise NotImplementedError(
  1638. f"`set_input_embeddings` not auto‑handled for {self.__class__.__name__}; please override in the subclass."
  1639. )
  1640. def get_output_embeddings(self):
  1641. if not hasattr(self, "lm_head"):
  1642. return None
  1643. try:
  1644. # Speech / vision backbones raise here, so we return None.
  1645. # Legit use of get_input_embs?
  1646. self.get_input_embeddings()
  1647. except NotImplementedError:
  1648. return None
  1649. return self.lm_head
  1650. def set_output_embeddings(self, new_embeddings):
  1651. """
  1652. Sets the model's output embedding, defaulting to setting new_embeddings to lm_head.
  1653. """
  1654. if getattr(self, "lm_head"):
  1655. self.lm_head = new_embeddings
  1656. class PreTrainedModel(nn.Module, EmbeddingAccessMixin, ModuleUtilsMixin, PushToHubMixin, PeftAdapterMixin):
  1657. r"""
  1658. Base class for all models.
  1659. [`PreTrainedModel`] takes care of storing the configuration of the models and handles methods for loading,
  1660. downloading and saving models as well as a few methods common to all models to:
  1661. - resize the input embeddings,
  1662. - prune heads in the self-attention heads.
  1663. Class attributes (overridden by derived classes):
  1664. - **config_class** ([`PretrainedConfig`]) -- A subclass of [`PretrainedConfig`] to use as configuration class
  1665. for this model architecture.
  1666. - **load_tf_weights** (`Callable`) -- A python *method* for loading a TensorFlow checkpoint in a PyTorch model,
  1667. taking as arguments:
  1668. - **model** ([`PreTrainedModel`]) -- An instance of the model on which to load the TensorFlow checkpoint.
  1669. - **config** ([`PreTrainedConfig`]) -- An instance of the configuration associated to the model.
  1670. - **path** (`str`) -- A path to the TensorFlow checkpoint.
  1671. - **base_model_prefix** (`str`) -- A string indicating the attribute associated to the base model in derived
  1672. classes of the same architecture adding modules on top of the base model.
  1673. - **is_parallelizable** (`bool`) -- A flag indicating whether this model supports model parallelization.
  1674. - **main_input_name** (`str`) -- The name of the principal input to the model (often `input_ids` for NLP
  1675. models, `pixel_values` for vision models and `input_values` for speech models).
  1676. - **can_record_outputs** (dict):"""
  1677. config_class = None
  1678. base_model_prefix = ""
  1679. main_input_name = "input_ids"
  1680. model_tags = None
  1681. _checkpoint_conversion_mapping = {} # used for BC support in VLMs, not meant to be used by new models
  1682. _auto_class = None
  1683. _no_split_modules = None
  1684. _skip_keys_device_placement = None
  1685. _keep_in_fp32_modules = None
  1686. # the _keep_in_fp32_modules will avoid casting to anything other than float32, except bfloat16
  1687. # to also prevent bfloat16 casting, use the _keep_in_fp32_modules_strict flag
  1688. _keep_in_fp32_modules_strict = None
  1689. # a list of `re` patterns of `state_dict` keys that should be removed from the list of missing
  1690. # keys we find (keys inside the model but not in the checkpoint) and avoid unnecessary warnings.
  1691. _keys_to_ignore_on_load_missing = None
  1692. # a list of `re` patterns of `state_dict` keys that should be removed from the list of
  1693. # unexpected keys we find (keys inside the checkpoint but not the model) and avoid unnecessary
  1694. # warnings.
  1695. _keys_to_ignore_on_load_unexpected = None
  1696. # a list of `state_dict` keys to ignore when saving the model (useful for keys that aren't
  1697. # trained, but which are either deterministic or tied variables)
  1698. _keys_to_ignore_on_save = None
  1699. # a list of `state_dict` keys that are potentially tied to another key in the state_dict.
  1700. _tied_weights_keys = None
  1701. is_parallelizable = False
  1702. supports_gradient_checkpointing = False
  1703. _is_stateful = False
  1704. # Flash Attention support
  1705. _supports_flash_attn = False
  1706. # SDPA support
  1707. _supports_sdpa = False
  1708. # Flex Attention support
  1709. _supports_flex_attn = False
  1710. _can_compile_fullgraph = False
  1711. # A tensor parallel plan to be applied to the model when TP is enabled. For
  1712. # top-level models, this attribute is currently defined in respective model
  1713. # code. For base models, this attribute comes from
  1714. # `config.base_model_tp_plan` during `__init__`.
  1715. # It should identify the layers exactly: if you want to TP model.language_model.layers.fc1
  1716. # by passing `tp_plan` to the init, it should be {"model.language_model.layers.fc1":"colwise"}
  1717. # for example.
  1718. _tp_plan = None
  1719. # tensor parallel degree to which model is sharded to.
  1720. _tp_size = None
  1721. # A pipeline parallel plan specifying the layers which may not be present
  1722. # on all ranks when PP is enabled. For top-level models, this attribute is
  1723. # currently defined in respective model code. For base models, this
  1724. # attribute comes from `config.base_model_pp_plan` during `post_init`.
  1725. #
  1726. # The variable names for the inputs and outputs of the specified layers can
  1727. # be indexed using the `PipelineParallel` enum as follows:
  1728. # - `_pp_plan["layers"][PipelineParallel.inputs]`
  1729. # - `_pp_plan["layers"][PipelineParallel.outputs]`
  1730. _pp_plan = None
  1731. # This flag signal that the model can be used as an efficient backend in TGI and vLLM
  1732. # In practice, it means that they support attention (mask) interface functions, fully pass the kwargs
  1733. # through all modules up to the Attention layer, can slice logits with Tensor, and have a default TP plan
  1734. _supports_attention_backend = False
  1735. _can_record_outputs = None
  1736. @property
  1737. @torch._dynamo.allow_in_graph
  1738. def can_record_outputs(self) -> dict[str, OutputRecorder]:
  1739. """
  1740. Maps output names (e.g., "attentions", "hidden_states")
  1741. to either:
  1742. - A module class (e.g., `LlamaDecoderLayer`), using default index conventions:
  1743. * index=0 for "hidden_states"
  1744. * index=1 for "attentions"
  1745. - Or an `OutputRecorder(...)` with `target_class`, optional `index`, and `layer_name`.
  1746. Examples:
  1747. These two are equivalent:
  1748. ```python
  1749. _can_record_outputs = {
  1750. "attentions": LlamaAttention,
  1751. "hidden_states": LlamaDecoderLayer
  1752. }
  1753. _can_record_outputs = {
  1754. "attentions": OutputRecorder(LlamaAttention, index=1),
  1755. "hidden_states": OutputRecorder(LlamaDecoderLayer, index=0)
  1756. }
  1757. ```
  1758. This means you can record outputs from the same class, by specifying a layer name. Before
  1759. collecting outputs, we check that they come from this layer.
  1760. If you have cross attention that come from `LlamaAttention` and self attention that also
  1761. come from `LlamaAttention` but from `self_attn` you can do this:
  1762. ```python
  1763. class LlamaModel(PreTrainedModel):
  1764. _can_record_outputs = {
  1765. "attentions": OutputRecorder(LlamaAttention, index=1, layer-name="self_attn"),
  1766. "cross_attentions": OutputRecorder(LlamaAttention, index=1, layer_name="cross_attn")
  1767. }
  1768. ```
  1769. """
  1770. return self._can_record_outputs or {}
  1771. @property
  1772. def dummy_inputs(self) -> dict[str, torch.Tensor]:
  1773. """
  1774. `dict[str, torch.Tensor]`: Dummy inputs to do a forward pass in the network.
  1775. """
  1776. return {"input_ids": torch.tensor(DUMMY_INPUTS)}
  1777. @property
  1778. def framework(self) -> str:
  1779. """
  1780. :str: Identifies that this is a PyTorch model.
  1781. """
  1782. return "pt"
  1783. def __init_subclass__(cls, **kwargs):
  1784. super().__init_subclass__(**kwargs)
  1785. # For BC we keep the original `config_class` definition in case
  1786. # there is a `config_class` attribute (e.g. remote code models),
  1787. # otherwise we derive it from the annotated `config` attribute.
  1788. # defined in this particular subclass
  1789. child_annotation = cls.__dict__.get("__annotations__", {}).get("config", None)
  1790. child_attribute = cls.__dict__.get("config_class", None)
  1791. # defined in the class (this subclass or any parent class)
  1792. full_annotation = get_type_hints(cls).get("config", None)
  1793. full_attribute = cls.config_class
  1794. # priority (child class_config -> child annotation -> global class_config -> global annotation)
  1795. if child_attribute is not None:
  1796. cls.config_class = child_attribute
  1797. elif child_annotation is not None:
  1798. cls.config_class = child_annotation
  1799. elif full_attribute is not None:
  1800. cls.config_class = full_attribute
  1801. elif full_annotation is not None:
  1802. cls.config_class = full_annotation
  1803. def __init__(self, config: PretrainedConfig, *inputs, **kwargs):
  1804. super().__init__()
  1805. if not isinstance(config, PretrainedConfig):
  1806. raise TypeError(
  1807. f"Parameter config in `{self.__class__.__name__}(config)` should be an instance of class "
  1808. "`PretrainedConfig`. To create a model from a pretrained model use "
  1809. f"`model = {self.__class__.__name__}.from_pretrained(PRETRAINED_MODEL_NAME)`"
  1810. )
  1811. self.config = config
  1812. # Check the attention implementation is supported, or set it if not yet set (on the internal attr, to avoid
  1813. # setting it recursively)
  1814. self.config._attn_implementation_internal = self._check_and_adjust_attn_implementation(
  1815. self.config._attn_implementation, is_init_check=True
  1816. )
  1817. # for initialization of the loss
  1818. loss_type = self.__class__.__name__
  1819. if loss_type not in LOSS_MAPPING:
  1820. loss_groups = f"({'|'.join(LOSS_MAPPING)})"
  1821. loss_type = re.findall(loss_groups, self.__class__.__name__)
  1822. if len(loss_type) > 0:
  1823. loss_type = loss_type[0]
  1824. else:
  1825. loss_type = None
  1826. self.loss_type = loss_type
  1827. self.name_or_path = config.name_or_path
  1828. self.warnings_issued = {}
  1829. self.generation_config = GenerationConfig.from_model_config(config) if self.can_generate() else None
  1830. # Overwrite the class attribute to make it an instance attribute, so models like
  1831. # `InstructBlipForConditionalGeneration` can dynamically update it without modifying the class attribute
  1832. # when a different component (e.g. language_model) is used.
  1833. self._keep_in_fp32_modules = copy.copy(self.__class__._keep_in_fp32_modules)
  1834. self._keep_in_fp32_modules_strict = copy.copy(self.__class__._keep_in_fp32_modules_strict)
  1835. self._no_split_modules = self._no_split_modules or []
  1836. _CAN_RECORD_REGISTRY[str(self.__class__)] = self._can_record_outputs # added for executorch support only
  1837. def post_init(self):
  1838. """
  1839. A method executed at the end of each Transformer model initialization, to execute code that needs the model's
  1840. modules properly initialized (such as weight initialization).
  1841. This is also used when the user is running distributed code. We add hooks to the modules here, according to
  1842. the model's tp_plan!
  1843. """
  1844. self.init_weights()
  1845. self._backward_compatibility_gradient_checkpointing()
  1846. # Make sure the modules correctly exist if the flag is active
  1847. if self._keep_in_fp32_modules is not None or self._keep_in_fp32_modules_strict is not None:
  1848. all_parameters = {name for name, _ in self.named_parameters() if len(name) > 0}
  1849. unique_module_names = set()
  1850. # Get all unique module names in the module graph, without the prefixes
  1851. for param in all_parameters:
  1852. unique_module_names.update(
  1853. [name for name in param.split(".") if not name.isnumeric() and name not in ["weight", "bias"]]
  1854. )
  1855. # Check that every module in the keep_in_fp32 list is part of the module graph
  1856. if self._keep_in_fp32_modules is not None:
  1857. for module in self._keep_in_fp32_modules:
  1858. if module not in unique_module_names:
  1859. raise ValueError(
  1860. f"{module} was specified in the `_keep_in_fp32_modules` list, but is not part of the modules in"
  1861. f" {self.__class__.__name__}"
  1862. )
  1863. if self._keep_in_fp32_modules_strict is not None:
  1864. for module in self._keep_in_fp32_modules_strict:
  1865. if module not in unique_module_names:
  1866. raise ValueError(
  1867. f"{module} was specified in the `_keep_in_fp32_modules_strict` list, but is not part of the modules in"
  1868. f" {self.__class__.__name__}"
  1869. )
  1870. # If current model is a base model, attach `base_model_tp_plan` and `base_model_pp_plan` from config
  1871. self._pp_plan = self.config.base_model_pp_plan.copy() if self.config.base_model_pp_plan is not None else {}
  1872. self._tp_plan = self.config.base_model_tp_plan.copy() if self.config.base_model_tp_plan is not None else {}
  1873. self._ep_plan = self.config.base_model_ep_plan.copy() if self.config.base_model_ep_plan is not None else {}
  1874. for name, module in self.named_children():
  1875. if plan := getattr(module, "_ep_plan", None):
  1876. self._ep_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
  1877. if plan := getattr(module, "_tp_plan", None):
  1878. self._tp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
  1879. if plan := getattr(module, "_pp_plan", None):
  1880. self._pp_plan.update({f"{name}.{k}": v for k, v in plan.copy().items()})
  1881. @property
  1882. def tp_plan(self) -> dict[str, str]:
  1883. """
  1884. The full tp plan for the model's modules
  1885. """
  1886. if hasattr(self.config, "distributed_config") and self.config.distributed_config.enable_expert_parallel:
  1887. return self._ep_plan
  1888. return self._tp_plan
  1889. @property
  1890. def pp_plan(self) -> dict[str, tuple[str, str]]:
  1891. return self._pp_plan
  1892. @tp_plan.setter
  1893. def tp_plan(self, plan: dict[str, str]):
  1894. if plan is not None:
  1895. # Validate that all parallel styles in the plan are supported
  1896. from .integrations.tensor_parallel import ALL_PARALLEL_STYLES
  1897. for layer_pattern, parallel_style in plan.items():
  1898. if parallel_style not in ALL_PARALLEL_STYLES:
  1899. raise ValueError(
  1900. f"Unsupported tensor parallel style '{parallel_style}' for layer '{layer_pattern}'. "
  1901. f"Supported styles are {list(ALL_PARALLEL_STYLES.keys())}"
  1902. )
  1903. # Validate that the layer patterns match existing model structure
  1904. # We check this by getting all parameter names and seeing if any match the patterns
  1905. if hasattr(self, "named_parameters"):
  1906. model_param_names = [name for name, _ in self.named_parameters()]
  1907. if model_param_names: # Only validate if model has parameters
  1908. for layer_pattern in plan.keys():
  1909. # Convert pattern to regex (replace * with .*)
  1910. regex_pattern = layer_pattern.replace("*", r"\d+")
  1911. pattern_matched = False
  1912. for param_name in model_param_names:
  1913. if re.match(regex_pattern, param_name):
  1914. pattern_matched = True
  1915. break
  1916. if not pattern_matched:
  1917. # Try more flexible matching - check if pattern components exist
  1918. pattern_parts = layer_pattern.split(".")
  1919. flexible_matched = False
  1920. for param_name in model_param_names:
  1921. param_parts = param_name.split(".")
  1922. if len(pattern_parts) <= len(param_parts):
  1923. match_count = 0
  1924. for i, pattern_part in enumerate(pattern_parts):
  1925. if pattern_part == "*":
  1926. match_count += 1
  1927. elif i < len(param_parts) and pattern_part == param_parts[i]:
  1928. match_count += 1
  1929. if match_count == len(pattern_parts):
  1930. flexible_matched = True
  1931. break
  1932. if not flexible_matched:
  1933. warnings.warn(
  1934. f"Layer pattern '{layer_pattern}' does not match any parameters in the model. "
  1935. f"This rule may not be applied during tensor parallelization."
  1936. )
  1937. self._tp_plan = plan if plan is not None else {}
  1938. @pp_plan.setter
  1939. def pp_plan(self, plan: dict[str, tuple[str, str]]):
  1940. self._pp_plan = plan
  1941. def dequantize(self):
  1942. """
  1943. Potentially dequantize the model in case it has been quantized by a quantization method that support
  1944. dequantization.
  1945. """
  1946. hf_quantizer = getattr(self, "hf_quantizer", None)
  1947. if hf_quantizer is None:
  1948. raise ValueError("You need to first quantize your model in order to dequantize it")
  1949. return hf_quantizer.dequantize(self)
  1950. def _backward_compatibility_gradient_checkpointing(self):
  1951. if self.supports_gradient_checkpointing and getattr(self.config, "gradient_checkpointing", False):
  1952. self.gradient_checkpointing_enable()
  1953. # Remove the attribute now that is has been consumed, so it's no saved in the config.
  1954. delattr(self.config, "gradient_checkpointing")
  1955. def add_model_tags(self, tags: Union[list[str], str]) -> None:
  1956. r"""
  1957. Add custom tags into the model that gets pushed to the Hugging Face Hub. Will
  1958. not overwrite existing tags in the model.
  1959. Args:
  1960. tags (`Union[list[str], str]`):
  1961. The desired tags to inject in the model
  1962. Examples:
  1963. ```python
  1964. from transformers import AutoModel
  1965. model = AutoModel.from_pretrained("google-bert/bert-base-cased")
  1966. model.add_model_tags(["custom", "custom-bert"])
  1967. # Push the model to your namespace with the name "my-custom-bert".
  1968. model.push_to_hub("my-custom-bert")
  1969. ```
  1970. """
  1971. if isinstance(tags, str):
  1972. tags = [tags]
  1973. if self.model_tags is None:
  1974. self.model_tags = []
  1975. for tag in tags:
  1976. if tag not in self.model_tags:
  1977. self.model_tags.append(tag)
  1978. @classmethod
  1979. @restore_default_dtype
  1980. def _from_config(cls, config, **kwargs):
  1981. """
  1982. All context managers that the model should be initialized under go here.
  1983. Args:
  1984. dtype (`torch.dtype`, *optional*):
  1985. Override the default `dtype` and load the model under this dtype.
  1986. """
  1987. # when we init a model from within another model (e.g. VLMs) and dispatch on FA2
  1988. # a warning is raised that dtype should be fp16. Since we never pass dtype from within
  1989. # modeling code, we can try to infer it here same way as done in `from_pretrained`
  1990. # For BC on the old `torch_dtype`
  1991. dtype = kwargs.pop("dtype", config.dtype)
  1992. if (torch_dtype := kwargs.pop("torch_dtype", None)) is not None:
  1993. logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
  1994. # if both kwargs are provided, use `dtype`
  1995. dtype = dtype if dtype != config.dtype else torch_dtype
  1996. if isinstance(dtype, str):
  1997. dtype = getattr(torch, dtype)
  1998. # override default dtype if needed
  1999. dtype_orig = None
  2000. if dtype is not None:
  2001. dtype_orig = cls._set_default_dtype(dtype)
  2002. # If passing `attn_implementation` as kwargs, respect it (it will be applied recursively on subconfigs)
  2003. if "attn_implementation" in kwargs:
  2004. config._attn_implementation = kwargs.pop("attn_implementation")
  2005. if is_deepspeed_zero3_enabled() and not _is_quantized and not _is_ds_init_called:
  2006. logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
  2007. # this immediately partitions the model across all gpus, to avoid the overhead in time
  2008. # and memory copying it on CPU or each GPU first
  2009. import deepspeed
  2010. init_contexts = [deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()]
  2011. with ContextManagers(init_contexts):
  2012. model = cls(config, **kwargs)
  2013. else:
  2014. model = cls(config, **kwargs)
  2015. # restore default dtype if it was modified
  2016. if dtype_orig is not None:
  2017. torch.set_default_dtype(dtype_orig)
  2018. return model
  2019. @classmethod
  2020. def _set_default_dtype(cls, dtype: torch.dtype) -> torch.dtype:
  2021. """
  2022. Change the default dtype and return the previous one. This is needed when wanting to instantiate the model
  2023. under specific dtype.
  2024. Args:
  2025. dtype (`torch.dtype`):
  2026. a floating dtype to set to.
  2027. Returns:
  2028. `torch.dtype`: the original `dtype` that can be used to restore `torch.set_default_dtype(dtype)` if it was
  2029. modified. If it wasn't, returns `None`.
  2030. Note `set_default_dtype` currently only works with floating-point types and asserts if for example,
  2031. `torch.int64` is passed. So if a non-float `dtype` is passed this functions will throw an exception.
  2032. """
  2033. if not dtype.is_floating_point:
  2034. raise ValueError(
  2035. f"Can't instantiate {cls.__name__} model under dtype={dtype} since it is not a floating point dtype"
  2036. )
  2037. logger.info(f"Instantiating {cls.__name__} model under default dtype {dtype}.")
  2038. dtype_orig = torch.get_default_dtype()
  2039. torch.set_default_dtype(dtype)
  2040. return dtype_orig
  2041. @property
  2042. def base_model(self) -> nn.Module:
  2043. """
  2044. `torch.nn.Module`: The main body of the model.
  2045. """
  2046. return getattr(self, self.base_model_prefix, self)
  2047. @classmethod
  2048. def can_generate(cls) -> bool:
  2049. """
  2050. Returns whether this model can generate sequences with `.generate()` from the `GenerationMixin`.
  2051. Under the hood, on classes where this function returns True, some generation-specific changes are triggered:
  2052. for instance, the model instance will have a populated `generation_config` attribute.
  2053. Returns:
  2054. `bool`: Whether this model can generate sequences with `.generate()`.
  2055. """
  2056. # Directly inherits `GenerationMixin` -> can generate
  2057. if "GenerationMixin" in str(cls.__bases__):
  2058. return True
  2059. # The class inherits from a class that can generate (recursive check) -> can generate
  2060. for base in cls.__bases__:
  2061. if not hasattr(base, "can_generate"):
  2062. continue
  2063. if "PreTrainedModel" not in str(base) and base.can_generate():
  2064. return True
  2065. # Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
  2066. # was how we detected whether a model could generate.
  2067. if hasattr(cls, "prepare_inputs_for_generation"): # implicit: doesn't inherit `GenerationMixin`
  2068. logger.warning(
  2069. f"{cls.__name__} has generative capabilities, as `prepare_inputs_for_generation` is explicitly "
  2070. "defined. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, "
  2071. "`PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability "
  2072. "to call `generate` and other related functions."
  2073. "\n - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the "
  2074. "model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes"
  2075. "\n - If you are the owner of the model architecture code, please modify your model class such that "
  2076. "it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception)."
  2077. "\n - If you are not the owner of the model architecture class, please contact the model code owner "
  2078. "to update it."
  2079. )
  2080. # Otherwise, can't generate
  2081. return False
  2082. def _flash_attn_2_can_dispatch(self, is_init_check: bool = False) -> bool:
  2083. """
  2084. Check the availability of Flash Attention 2 for a given model.
  2085. Args:
  2086. is_init_check (`bool`, *optional*):
  2087. Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
  2088. fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
  2089. BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
  2090. before instantiating the full models if we know that the model does not support the requested attention.
  2091. """
  2092. dtype = self.config.dtype
  2093. # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
  2094. if not (self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False)):
  2095. raise ValueError(
  2096. f"{self.__class__.__name__} does not support Flash Attention 2.0 yet. Please request to add support where"
  2097. f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
  2098. " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
  2099. )
  2100. if not is_flash_attn_2_available():
  2101. preface = "FlashAttention2 has been toggled on, but it cannot be used due to the following error:"
  2102. install_message = "Please refer to the documentation of https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2 to install Flash Attention 2."
  2103. # package `flash-attn` can not be installed on Ascend NPU, following validation logics can be ignored.
  2104. if is_torch_npu_available():
  2105. logger.info("Detect using FlashAttention2 on Ascend NPU.")
  2106. return True
  2107. if importlib.util.find_spec("flash_attn") is None:
  2108. raise ImportError(f"{preface} the package flash_attn seems to be not installed. {install_message}")
  2109. else:
  2110. # Check FA2 installed version compatibility
  2111. flash_attention_version = version.parse(importlib.metadata.version("flash_attn"))
  2112. if torch.version.cuda:
  2113. if flash_attention_version < version.parse("2.1.0"):
  2114. raise ImportError(
  2115. f"{preface} you need flash_attn package version to be greater or equal than 2.1.0. Detected version {flash_attention_version}. {install_message}"
  2116. )
  2117. elif not torch.cuda.is_available():
  2118. raise ValueError(
  2119. f"{preface} Flash Attention 2 is not available on CPU. Please make sure torch can access a CUDA device."
  2120. )
  2121. else:
  2122. raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
  2123. elif torch.version.hip:
  2124. if flash_attention_version < version.parse("2.0.4"):
  2125. raise ImportError(
  2126. f"{preface} you need flash_attn package version to be greater or equal than 2.0.4. Detected version {flash_attention_version}. {install_message}"
  2127. )
  2128. else:
  2129. raise ImportError(f"{preface} Flash Attention 2 is not available. {install_message}")
  2130. if dtype is None:
  2131. logger.warning_once(
  2132. "You are attempting to use Flash Attention 2 without specifying a torch dtype. This might lead to unexpected behaviour"
  2133. )
  2134. elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]:
  2135. logger.warning_once(
  2136. "Flash Attention 2 only supports torch.float16 and torch.bfloat16 dtypes, but"
  2137. f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
  2138. ' or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", dtype=torch.float16)`'
  2139. )
  2140. # With the early check, the parameters are not yet initialized correctly
  2141. if not is_init_check:
  2142. if getattr(self, "use_bettertransformer", False):
  2143. raise ValueError(
  2144. "Flash Attention 2 and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
  2145. )
  2146. param_devices = list({param.device for param in self.parameters()})
  2147. if len(param_devices) == 1 and param_devices[0].type == "cpu":
  2148. if torch.cuda.is_available():
  2149. logger.warning_once(
  2150. "You are attempting to use Flash Attention 2 with a model not initialized on GPU. Make sure to move the model to GPU"
  2151. " after initializing it on CPU with `model.to('cuda')`."
  2152. )
  2153. elif is_torch_mlu_available():
  2154. logger.warning_once(
  2155. "You are attempting to use Flash Attention 2 with a model not initialized on MLU. Make sure to move the model to MLU"
  2156. " after initializing it on CPU with `model.to('mlu')`."
  2157. )
  2158. else:
  2159. raise ValueError(
  2160. "You are attempting to use Flash Attention 2 with a model not initialized on GPU and with no GPU available. "
  2161. "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
  2162. "or initialising the model on CPU and then moving it to GPU."
  2163. )
  2164. # If no error raise by this point, we can return `True`
  2165. return True
  2166. def _flash_attn_3_can_dispatch(self, is_init_check: bool = False) -> bool:
  2167. """
  2168. Check the availability of Flash Attention 3 for a given model.
  2169. Args:
  2170. is_init_check (`bool`, *optional*):
  2171. Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
  2172. fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
  2173. BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
  2174. before instantiating the full models if we know that the model does not support the requested attention.
  2175. """
  2176. dtype = self.config.dtype
  2177. if not self._supports_flash_attn:
  2178. raise ValueError(
  2179. f"{self.__class__.__name__} does not support Flash Attention 3 yet. Please request to add support where"
  2180. f" the model is hosted, on its model hub page: https://huggingface.co/{self.config._name_or_path}/discussions/new"
  2181. " or in the Transformers GitHub repo: https://github.com/huggingface/transformers/issues/new"
  2182. )
  2183. if not is_flash_attn_3_available():
  2184. preface = "FlashAttention3 has been toggled on, but it cannot be used due to the following error:"
  2185. if importlib.util.find_spec("flash_attn_3") is None:
  2186. raise ImportError(f"{preface} the package flash_attn_3 seems to be not installed.")
  2187. if torch.cuda.is_available():
  2188. major, _ = torch.cuda.get_device_capability()
  2189. if major < 9:
  2190. raise ValueError(
  2191. f"{preface} Flash Attention 3 requires compute capability >= 9.0, but found {torch.cuda.get_device_capability()} with compute capability {major}.0."
  2192. )
  2193. else:
  2194. raise ImportError(f"{preface} Flash Attention 3 is not available.")
  2195. else:
  2196. raise ValueError(
  2197. f"{preface} Flash Attention 3 is not available on CPU. Please make sure torch can access a CUDA device."
  2198. )
  2199. if dtype is None:
  2200. logger.warning_once(
  2201. "You are attempting to use Flash Attention 3 without specifying a torch dtype. This might lead to unexpected behaviour"
  2202. )
  2203. elif dtype is not None and dtype not in [torch.float16, torch.bfloat16]:
  2204. logger.warning_once(
  2205. "Flash Attention 3 only supports torch.float16 and torch.bfloat16 dtypes, but"
  2206. f" the current dype in {self.__class__.__name__} is {dtype}. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator,"
  2207. ' or load the model with the `dtype` argument. Example: `model = AutoModel.from_pretrained("meta-llama/Llama-3.2-1B", attn_implementation="flash_attention_3", dtype=torch.float16)`'
  2208. )
  2209. if getattr(self.config, "alibi", False) or getattr(self.config, "use_alibi", False):
  2210. raise ValueError("Model is configured to use ALiBi, which is not supported by Flash Attention 3.")
  2211. # Check for attention dropout, which is incompatible with FA3
  2212. if hasattr(self.config, "attention_dropout") and self.config.attention_dropout > 0:
  2213. raise ValueError(
  2214. f"Model has attention_dropout={self.config.attention_dropout}, which is not supported by Flash Attention 3."
  2215. )
  2216. # With the early check, the parameters are not yet initialized correctly
  2217. if not is_init_check:
  2218. param_devices = list({param.device for param in self.parameters()})
  2219. if len(param_devices) == 1 and param_devices[0].type == "cpu":
  2220. if torch.cuda.is_available():
  2221. logger.warning_once(
  2222. "You are attempting to use Flash Attention 3 with a model not initialized on GPU. Make sure to move the model to GPU"
  2223. " after initializing it on CPU with `model.to('cuda')`."
  2224. )
  2225. else:
  2226. raise ValueError(
  2227. "You are attempting to use Flash Attention 3 with a model not initialized on GPU and with no GPU available. "
  2228. "This is not supported yet. Please make sure to have access to a GPU and either initialise the model on a GPU by passing a device_map "
  2229. "or initialising the model on CPU and then moving it to GPU."
  2230. )
  2231. return True
  2232. def _sdpa_can_dispatch(self, is_init_check: bool = False) -> bool:
  2233. """
  2234. Check the availability of SDPA for a given model.
  2235. Args:
  2236. is_init_check (`bool`, *optional*):
  2237. Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
  2238. fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
  2239. BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
  2240. before instantiating the full models if we know that the model does not support the requested attention.
  2241. """
  2242. if not self._supports_sdpa:
  2243. raise ValueError(
  2244. f"{self.__class__.__name__} does not support an attention implementation through torch.nn.functional.scaled_dot_product_attention yet."
  2245. " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/28005. If you believe"
  2246. ' this error is a bug, please open an issue in Transformers GitHub repository and load your model with the argument `attn_implementation="eager"` meanwhile. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
  2247. )
  2248. if (
  2249. torch.version.hip is not None
  2250. and torch.cuda.device_count() > 1
  2251. and version.parse(torch.__version__) < version.parse("2.4.1")
  2252. ):
  2253. logger.warning_once(
  2254. "Using the `SDPA` attention implementation on multi-gpu setup with ROCM may lead to performance issues due to the FA backend. Disabling it to use alternative backends."
  2255. )
  2256. torch.backends.cuda.enable_flash_sdp(False)
  2257. if not is_init_check:
  2258. if getattr(self, "use_bettertransformer", False):
  2259. raise ValueError(
  2260. "SDPA and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
  2261. )
  2262. return True
  2263. def _flex_attn_can_dispatch(self, is_init_check: bool = False) -> bool:
  2264. """
  2265. Check the availability of Flex Attention for a given model.
  2266. Args:
  2267. is_init_check (`bool`, *optional*):
  2268. Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
  2269. fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
  2270. BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
  2271. before instantiating the full models if we know that the model does not support the requested attention.
  2272. """
  2273. if not self._supports_flex_attn:
  2274. raise ValueError(
  2275. f"{self.__class__.__name__} does not support an attention implementation through torch's flex_attention."
  2276. " Please request the support for this architecture: https://github.com/huggingface/transformers/issues/34809."
  2277. " If you believe this error is a bug, please open an issue in Transformers GitHub repository"
  2278. ' and load your model with the argument `attn_implementation="eager"` meanwhile.'
  2279. ' Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="eager")`'
  2280. )
  2281. if not is_torch_flex_attn_available():
  2282. raise ImportError(
  2283. "PyTorch Flex Attention requirements in Transformers are not met. Please install torch>=2.5.0."
  2284. )
  2285. if not is_init_check:
  2286. if getattr(self, "use_bettertransformer", False):
  2287. raise ValueError(
  2288. "FlexAttention and BetterTransformer API are not compatible. Please make sure to disable BetterTransformers by doing model.reverse_bettertransformer()"
  2289. )
  2290. # If no error raise by this point, we can return `True`
  2291. return True
  def _check_and_adjust_attn_implementation(
      self, attn_implementation: Optional[str], is_init_check: bool = False
  ) -> str:
      """
      Check that the `attn_implementation` exists and is supported by the models, and try to get the kernel from hub if
      it matches hf kernels pattern.

      If a `flash_attention*` implementation is requested, the model supports flash attention, the requested version
      is not installed and the `kernels` library is available, the matching hub kernel is used instead
      (`kernels-community/flash-attn` for version 2, `kernels-community/vllm-flash-attn3` otherwise).

      Args:
          attn_implementation (`str` or `None`):
              The attention implementation to check for existence/validity.
          is_init_check (`bool`, *optional*):
              Whether this check is performed early, i.e. at __init__ time, or later when the model and its weights are
              fully instantiated. This is needed as we also check the devices of the weights, and/or if the model uses
              BetterTransformer, which are only available later after __init__. This allows to raise proper exceptions early
              before instantiating the full models if we know that the model does not support the requested attention.

      Returns:
          `str`: The final attention implementation to use, including potential fallbacks from sdpa to eager, or from
          None to sdpa (to potentially eager).

      Raises:
          Exception: whatever loading a hub kernel raises, preceded (for a requested flash attention version) by the
              more descriptive error from the corresponding flash attention availability check, if it raises one.
      """
      applicable_attn_implementation = attn_implementation

      # If the requested Flash Attention version is not installed, do not fail but use kernels instead
      if (
          attn_implementation is not None
          and attn_implementation.startswith("flash_attention")
          and self._supports_flash_attn
          and is_kernels_available()
      ):
          wants_fa2 = attn_implementation.endswith("2")
          # Only the requested version matters: e.g. having FA3 installed does not make FA2 usable
          requested_fa_installed = is_flash_attn_2_available() if wants_fa2 else is_flash_attn_3_available()
          if not requested_fa_installed:
              applicable_attn_implementation = (
                  "kernels-community/flash-attn" if wants_fa2 else "kernels-community/vllm-flash-attn3"
              )

      if is_kernel(applicable_attn_implementation):
          try:
              load_and_register_kernel(applicable_attn_implementation)
              # log that we used kernel fallback if successful
              if attn_implementation.startswith("flash_attention"):
                  logger.warning_once(
                      f"You do not have `flash_attn` installed, using `{applicable_attn_implementation}` "
                      "from the `kernels` library instead!"
                  )
          except Exception:
              # raise the proper exception for requested flash attention (with the same init-check semantics)
              if attn_implementation.startswith("flash_attention"):
                  if attn_implementation.endswith("2"):
                      self._flash_attn_2_can_dispatch(is_init_check)
                  else:
                      self._flash_attn_3_can_dispatch(is_init_check)
              # error properly out if a kernel was specifically requested
              raise
      else:
          applicable_attn_implementation = self.get_correct_attn_implementation(
              applicable_attn_implementation, is_init_check
          )
          # preload flash attention here to allow compile with fullgraph
          if applicable_attn_implementation.startswith("flash_attention"):
              lazy_import_flash_attention(applicable_attn_implementation, force_import=True)

      return applicable_attn_implementation
  def get_correct_attn_implementation(self, requested_attention: Optional[str], is_init_check: bool = False) -> str:
      """
      Validate the requested attention implementation and return the one this model should use.

      Args:
          requested_attention (`str` or `None`):
              The requested implementation. `None` means "use the default", i.e. `"sdpa"`.
          is_init_check (`bool`, *optional*):
              Forwarded to the per-implementation availability checks; `True` when called at `__init__` time, before
              weights and devices are known.

      Returns:
          `str`: the implementation to use. If SDPA was only picked as the default (not explicitly requested) and it
          cannot be dispatched, `"eager"` is returned instead.

      Raises:
          ValueError: if the name is neither `"eager"` nor a key registered in `ALL_ATTENTION_FUNCTIONS`.
          ValueError / ImportError: from the availability checks of the selected implementation (for `"sdpa"`, only
              when it was explicitly requested).
      """
      applicable_attention = "sdpa" if requested_attention is None else requested_attention
      if applicable_attention not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
          message = (
              f'Specified `attn_implementation="{applicable_attention}"` is not supported. The only possible arguments are '
              '`attn_implementation="eager"`'
          )
          # check `supports_flash_attn_2` for BC with custom code. TODO: remove after a few releases
          if self._supports_flash_attn or getattr(self, "_supports_flash_attn_2", False):
              message += ', `attn_implementation="flash_attention_3"`, `attn_implementation="flash_attention_2"`'
          if self._supports_sdpa:
              message += ', `attn_implementation="sdpa"`'
          if self._supports_flex_attn:
              message += ', `attn_implementation="flex_attention"`'
          raise ValueError(message + ".")

      # Perform relevant checks
      if applicable_attention == "flash_attention_2":
          self._flash_attn_2_can_dispatch(is_init_check)
      elif applicable_attention == "flash_attention_3":
          self._flash_attn_3_can_dispatch(is_init_check)
      elif applicable_attention == "flex_attention":
          self._flex_attn_can_dispatch(is_init_check)
      elif applicable_attention == "sdpa":
          # Sdpa is the default, so we try it and fallback to eager otherwise when not possible
          try:
              self._sdpa_can_dispatch(is_init_check)
          except (ValueError, ImportError):
              if requested_attention == "sdpa":
                  raise
              applicable_attention = "eager"

      return applicable_attention
  2380. @classmethod
  2381. def _can_set_attn_implementation(cls) -> bool:
  2382. """Detect whether the class supports setting its attention implementation dynamically. It is an ugly check based on
  2383. opening the file, but avoids maintaining yet another property flag.
  2384. """
  2385. class_file = sys.modules[cls.__module__].__file__
  2386. with open(class_file, "r") as f:
  2387. code = f.read()
  2388. # heuristic -> if we find those patterns, the model uses the correct interface
  2389. if re.search(r"class \w+Attention\(nn.Module\)", code):
  2390. return (
  2391. "eager_attention_forward" in code
  2392. and "ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]" in code
  2393. )
  2394. else:
  2395. # If no attention layer, assume `True`. Most probably a multimodal model or inherits from existing models
  2396. return True
  def set_attn_implementation(self, attn_implementation: Union[str, dict]):
      """
      Set the requested `attn_implementation` for this model.

      Args:
          attn_implementation (`str` or `dict`):
              The attention implementation to set for this model. It can be either a `str`, in which case it will be
              dispatched to all submodels if relevant, or a `dict` where keys are the sub_configs name, in which case each
              submodel will dispatch the corresponding value. In the `dict` form, the key `""` targets this model's own
              config, and missing keys keep the current implementation.

      Raises:
          ValueError: if a sub-config without a matching `PreTrainedModel` sub-module is given a name that is neither
              `"eager"` nor registered in `ALL_ATTENTION_FUNCTIONS`. Errors raised by the availability checks of the
              requested implementation(s) propagate as well.
      """
      requested_implementation = (
          attn_implementation
          if not isinstance(attn_implementation, dict)
          else attn_implementation.get("", self.config._attn_implementation)
      )

      # At this point, the model was already instantiated, so instead of crashing on bad value, let's simply
      # warn the user that the requested value is not working
      if requested_implementation != self.config._attn_implementation:
          # In this case, warn and keep the current implementation
          if not self._can_set_attn_implementation():
              logger.warning(
                  f"{self.__class__.__name__} does not support setting its attention implementation dynamically, because it "
                  "does not follow the functional approach based on AttentionInterface "
                  "(see https://huggingface.co/docs/transformers/en/attention_interface)"
              )
          else:
              requested_implementation = self._check_and_adjust_attn_implementation(
                  requested_implementation, is_init_check=False
              )
              # Apply the change (on the internal attr, to avoid setting it recursively)
              self.config._attn_implementation_internal = requested_implementation

      # Every config tagged with `_attn_was_changed` below. All tags are cleared at the end, so that a later call does
      # not mistake a leftover tag for "already handled" and silently skip that sub-model.
      tagged_configs = []

      # Apply it to all submodels as well
      for submodule in self.modules():
          # We found a submodel (which is not self) with a different config (otherwise, it may be the same "actual model",
          # e.g. ForCausalLM has a Model inside, but no need to check it again)
          if (
              submodule is not self
              and isinstance(submodule, PreTrainedModel)
              and submodule.config.__class__ != self.config.__class__
              # If it was already changed, no need to do it again
              and not hasattr(submodule.config, "_attn_was_changed")
          ):
              # In this case, warn and skip
              if not submodule._can_set_attn_implementation():
                  logger.warning(
                      f"{submodule.__class__.__name__} does not support setting its attention implementation dynamically, because it "
                      "does not follow the functional approach based on AttentionInterface "
                      "(see https://huggingface.co/docs/transformers/en/attention_interface)"
                  )
              # Set the attn on the submodule
              else:
                  sub_implementation = requested_implementation
                  if isinstance(attn_implementation, dict):
                      for subconfig_key in self.config.sub_configs:
                          # We need to check for exact object match here, with `is`
                          if getattr(self.config, subconfig_key) is submodule.config:
                              sub_implementation = attn_implementation.get(
                                  subconfig_key, submodule.config._attn_implementation
                              )
                              break
                  # Check the module can use correctly, otherwise we raise an error if requested attention can't be set for submodule
                  sub_implementation = submodule.get_correct_attn_implementation(sub_implementation)
                  submodule.config._attn_implementation_internal = sub_implementation

              # Still add it as "changed" even if it was skipped, as we would otherwise try to set it in the dark afterwards
              # We need to set it on the config itself, to differentiate 2 subconfigs of the same __class__ potentially
              submodule.config._attn_was_changed = True
              tagged_configs.append(submodule.config)

      # We need this as some old and badly designed models use subconfigs without declaring the corresponding modules as PreTrainedModel
      for subconfig_key in self.config.sub_configs:
          subconfig = getattr(self.config, subconfig_key)
          # Optional sub-configs may be unset: nothing to configure then
          if subconfig is None:
              continue
          sub_implementation = (
              requested_implementation
              if not isinstance(attn_implementation, dict)
              else attn_implementation.get(subconfig_key, subconfig._attn_implementation)
          )
          # This means we did not perform any check above for this particular subconfig -> set it in the dark if it is registered
          if (
              not hasattr(subconfig, "_attn_was_changed")
              # If it's already the same, then no need to enter here and raise warnings
              and sub_implementation != subconfig._attn_implementation
          ):
              if sub_implementation not in ["eager"] + ALL_ATTENTION_FUNCTIONS.valid_keys():
                  raise ValueError(
                      f'Specified `attn_implementation="{sub_implementation}"` is not supported for {subconfig_key}. '
                      'The only possible arguments are "eager" (manual attention implementation) '
                      f"or one of the following: {list(ALL_ATTENTION_FUNCTIONS.valid_keys())}"
                  )
              subconfig._attn_implementation_internal = sub_implementation
              logger.warning(
                  f"We set the attention implementation for the sub-config `{subconfig_key}` to `{sub_implementation}` "
                  "without finding the associated sub-model. For this reason we could not check if the model supports it. "
                  "You may encounter undefined behavior."
              )
          # Unset the attribute in this case, to avoid issues in the future
          else:
              if hasattr(subconfig, "_attn_was_changed"):
                  del subconfig._attn_was_changed

      # Clear any remaining tags, e.g. on nested sub-models whose config is not a direct sub-config of `self.config`
      for tagged_config in tagged_configs:
          if hasattr(tagged_config, "_attn_was_changed"):
              del tagged_config._attn_was_changed
  def enable_input_require_grads(self):
      """
      Enables the gradients for the input embeddings. This is useful for fine-tuning adapter weights while keeping
      the model weights fixed.

      A forward hook is registered on `self.get_input_embeddings()` that calls `requires_grad_(True)` on the
      embedding output, and its handle is stored in `self._require_grads_hook`. Calling this method again removes
      the previously stored hook first, so hooks do not pile up and the stored handle always refers to the only
      active hook.
      """

      def make_inputs_require_grads(module, input, output):
          output.requires_grad_(True)

      input_embeddings = self.get_input_embeddings()
      # Without this, a second call would leave the first hook registered with no handle left to remove it
      previous_hook = getattr(self, "_require_grads_hook", None)
      if previous_hook is not None:
          previous_hook.remove()
      self._require_grads_hook = input_embeddings.register_forward_hook(make_inputs_require_grads)
  def disable_input_require_grads(self):
      """
      Remove the forward hook stored in `self._require_grads_hook` (the one registered by
      `enable_input_require_grads`). Raises `AttributeError` if no such handle was ever stored.
      """
      hook_handle = self._require_grads_hook
      hook_handle.remove()
  def get_decoder(self):
      """
      Best-effort lookup of the *decoder* module.

      Order of attempts:
          1. `self.decoder`, if that attribute exists.
          2. `self.model.get_decoder()`, if `self.model` exists, exposes `get_decoder`, and is not of the same type
             as `self` (see https://github.com/huggingface/transformers/issues/40815).
          3. `self.model` itself, if it exists.
          4. `self`: a base model with neither a `decoder` nor a `model` attribute (e.g. `MistralModel`) is treated
             as the decoder.

      The lookup never fails for lack of a match; models that need a bespoke rule should override this method.
      """
      if hasattr(self, "decoder"):
          return self.decoder

      if hasattr(self, "model"):
          inner = self.model
          # See: https://github.com/huggingface/transformers/issues/40815
          if hasattr(inner, "get_decoder") and type(inner) is not type(self):
              return inner.get_decoder()
          return inner

      # If this is a base transformer model (no decoder/model attributes), return self
      # This handles cases like MistralModel which is itself the decoder
      return self
  def set_decoder(self, decoder):
      """
      Replace the decoder, using a lookup order that follows `get_decoder`:

          1. if `self.decoder` exists, it is replaced;
          2. otherwise, if `self.model` exists, `self.model.set_decoder(decoder)` is called when available, and
             `self.model` itself is replaced when it is not;
          3. otherwise nothing happens.

      Unlike `get_decoder`, there is no guard against `self.model` having the same type as `self`.
      """
      if hasattr(self, "decoder"):
          self.decoder = decoder
      elif hasattr(self, "model"):
          inner = self.model
          if hasattr(inner, "set_decoder"):
              inner.set_decoder(decoder)
          else:
              self.model = decoder
  2540. def _init_weights(self, module):
  2541. """
  2542. Initialize the weights. This is quite general on purpose, in the spirit of what we usually do. For more complex
  2543. initialization scheme, it should be overridden by the derived `PreTrainedModel` class. In case a model adds an explicit
  2544. `nn.Parameter`, this method should also be overridden in order to initialize it correctly.
  2545. """
  2546. if hasattr(self.config, "initializer_range"):
  2547. std = self.config.initializer_range
  2548. else:
  2549. # 0.02 is the standard default value across the library
  2550. std = getattr(self.config.get_text_config(), "initializer_range", 0.02)
  2551. if isinstance(module, (nn.Linear, nn.Conv1d, nn.Conv2d, nn.Conv3d, nn.ConvTranspose1d, nn.ConvTranspose2d)):
  2552. module.weight.data.normal_(mean=0.0, std=std)
  2553. if module.bias is not None:
  2554. module.bias.data.zero_()
  2555. elif isinstance(module, nn.Embedding):
  2556. module.weight.data.normal_(mean=0.0, std=std)
  2557. if module.padding_idx is not None:
  2558. module.weight.data[module.padding_idx].zero_()
  2559. elif isinstance(module, nn.MultiheadAttention):
  2560. # This uses torch's original init
  2561. module._reset_parameters()
  2562. # We cannot use `isinstance` on the RMSNorms or LayerNorms, as they usually are custom modules which change names
  2563. # between modelings (because they are prefixed with the model name)
  2564. elif (
  2565. isinstance(module, (nn.GroupNorm, nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d))
  2566. or "LayerNorm" in module.__class__.__name__
  2567. or "RMSNorm" in module.__class__.__name__
  2568. ):
  2569. # Norms can exist without weights (in which case they are None from torch primitives)
  2570. if hasattr(module, "weight") and module.weight is not None:
  2571. module.weight.data.fill_(1.0)
  2572. if hasattr(module, "bias") and module.bias is not None:
  2573. module.bias.data.zero_()
  2574. def _initialize_weights(self, module):
  2575. """
  2576. Initialize the weights if they are not already initialized.
  2577. """
  2578. if getattr(module, "_is_hf_initialized", False):
  2579. return
  2580. self._init_weights(module)
  2581. module._is_hf_initialized = True
  @torch.no_grad()
  def initialize_weights(self):
      """
      Initialize every module of the model with the right `_initialize_weights`, even in composite models.

      This behaves like `self.apply(self._initialize_weights)`, except that the function applied is swapped whenever
      the traversal enters a nested `PreTrainedModel`: that sub-model's own `_initialize_weights` is then used for it
      and everything below it. Without this, every composite model would have to recurse a second time on all
      sub-models explicitly in the outer-most `_init_weights`, which is extremely error prone and inefficient.

      The first call installs a `smart_apply` method on `torch.nn.Module` (a process-wide monkeypatch) that performs
      this traversal: children first, then the module itself, returning the module.

      The `torch.no_grad()` decorator matters: most `_init_weights` implementations do in-place ops such as
      `module.weight.data.zero_()` instead of using the (already no_grad) `torch.nn.init` functions.
      """
      if not hasattr(torch.nn.Module, "smart_apply"):

          def smart_apply(self, fn):
              for child in self.children():
                  # Entering a sub-model: from here on, use that sub-model's own init function
                  child_fn = child._initialize_weights if isinstance(child, PreTrainedModel) else fn
                  child.smart_apply(child_fn)
              fn(self)
              return self

          torch.nn.Module.smart_apply = smart_apply

      self.smart_apply(self._initialize_weights)
  def tie_embeddings_and_encoder_decoder(self):
      """
      Tie the output embeddings to the input embeddings and/or the encoder to the decoder, as the config requests.

      - Embeddings are tied (via `_tie_or_clone_weights`) unless the decoder text config sets
        `tie_word_embeddings=False`, and only when `get_output_embeddings()` returns something.
      - Encoder/decoder weights are tied when the config has both `is_encoder_decoder` and `tie_encoder_decoder`
        set. This happens on the base model found under `base_model_prefix` (or on this model if there is no such
        attribute), and the tied names are stored on that same object as `_dynamic_tied_weights_keys`.

      If the `torchscript` flag is set in the configuration, can't handle parameter sharing so we are cloning the
      weights instead.
      """
      decoder_text_config = self.config.get_text_config(decoder=True)
      if getattr(decoder_text_config, "tie_word_embeddings", True):
          output_embeddings = self.get_output_embeddings()
          if output_embeddings is not None:
              self._tie_or_clone_weights(output_embeddings, self.get_input_embeddings())

      tie_encoder_decoder = getattr(self.config, "is_encoder_decoder", False) and getattr(
          self.config, "tie_encoder_decoder", False
      )
      if not tie_encoder_decoder:
          return

      base_model = getattr(self, self.base_model_prefix) if hasattr(self, self.base_model_prefix) else self
      tied_weights = base_model._tie_encoder_decoder_weights(
          base_model.encoder, base_model.decoder, base_model.base_model_prefix, "encoder"
      )
      # Setting a dynamic variable instead of `_tied_weights_keys` because it's a class
      # attributed not an instance member, therefore modifying it will modify the entire class
      # Leading to issues on subsequent calls by different tests or subsequent calls.
      base_model._dynamic_tied_weights_keys = tied_weights
  def tie_weights(self):
      """
      Tie all the weights of the model, including those of every nested sub-model.

      Every module yielded by `self.modules()` (which includes `self`) is visited in turn:
          - a `PreTrainedModel` gets its `tie_embeddings_and_encoder_decoder()` called;
          - any module defining `_tie_weights` (not only `PreTrainedModel` instances) gets that hook called too,
            right after the step above for the same module.
      """
      for submodule in self.modules():
          if isinstance(submodule, PreTrainedModel):
              submodule.tie_embeddings_and_encoder_decoder()
          # Custom, model-specific tying logic is honored wherever it is defined
          if hasattr(submodule, "_tie_weights"):
              submodule._tie_weights()
  @staticmethod
  def _tie_encoder_decoder_weights(
      encoder: nn.Module, decoder: nn.Module, base_model_prefix: str, base_encoder_name: str
  ) -> list[str]:
      """
      Tie encoder parameters to decoder parameters by walking both module trees in parallel.

      Whenever a decoder sub-module has a `weight` attribute, the matching encoder sub-module's `weight` is re-pointed
      at the decoder's tensor (and likewise `bias`, if the decoder sub-module has one), so both share the same object.
      Children are matched by name. For numeric names (e.g. `nn.ModuleList` indices) an offset is tracked: when the
      two child lists differ in length and a decoder child's type does not match the encoder child at the same
      (offset) index, that decoder child is skipped and later decoder indices map one position lower in the encoder.

      Args:
          encoder (`nn.Module`): Module whose `weight`/`bias` attributes are replaced.
          decoder (`nn.Module`): Module whose `weight`/`bias` tensors are reused.
          base_model_prefix (`str`): Root of the "/"-separated paths used to report encoder sub-modules that were
              not visited (logged as a warning).
          base_encoder_name (`str`): Prefix of the dotted names returned for tied parameters.

      Returns:
          `list[str]`: Names such as `f"{base_encoder_name}.<sub.module.path>.weight"` (and `.bias`) of the encoder
          attributes that now alias decoder tensors.
      """
      uninitialized_encoder_weights: list[str] = []
      tied_weights: list[str] = []
      if decoder.__class__ != encoder.__class__:
          logger.info(
              f"{decoder.__class__} and {encoder.__class__} are not equal. In this case make sure that all encoder"
              " weights are correctly initialized."
          )

      def tie_encoder_to_decoder_recursively(
          decoder_pointer: nn.Module,
          encoder_pointer: nn.Module,
          module_name: str,
          base_encoder_name: str,
          uninitialized_encoder_weights: list[str],
          depth=0,
          total_decoder_name="",
          total_encoder_name="",
      ):
          assert isinstance(decoder_pointer, nn.Module) and isinstance(encoder_pointer, nn.Module), (
              f"{decoder_pointer} and {encoder_pointer} have to be of type nn.Module"
          )
          # Leaf case: share `weight` (and `bias`) and stop descending into this sub-module
          if hasattr(decoder_pointer, "weight"):
              assert hasattr(encoder_pointer, "weight")
              encoder_pointer.weight = decoder_pointer.weight
              tied_weights.append(f"{base_encoder_name}{total_encoder_name}.weight")
              if hasattr(decoder_pointer, "bias"):
                  assert hasattr(encoder_pointer, "bias")
                  tied_weights.append(f"{base_encoder_name}{total_encoder_name}.bias")
                  encoder_pointer.bias = decoder_pointer.bias
              return

          encoder_modules = encoder_pointer._modules
          decoder_modules = decoder_pointer._modules
          if len(decoder_modules) > 0:
              assert len(encoder_modules) > 0, (
                  f"Encoder module {encoder_pointer} does not match decoder module {decoder_pointer}"
              )

              # Every encoder child starts as "not visited"; its entry is removed once the matching child is recursed into
              all_encoder_weights = {module_name + "/" + sub_name for sub_name in encoder_modules}
              # Offset from decoder index to encoder index for numeric child names (decremented for skipped layers)
              encoder_layer_pos = 0
              for name in decoder_modules:
                  if name.isdigit():
                      encoder_name = str(int(name) + encoder_layer_pos)
                      decoder_name = name
                      # NOTE(review): `encoder_modules[encoder_name]` raises KeyError if the decoder has more numbered
                      # children than the encoder can map to.
                      if not isinstance(decoder_modules[decoder_name], type(encoder_modules[encoder_name])) and len(
                          encoder_modules
                      ) != len(decoder_modules):
                          # this can happen if the name corresponds to the position in a list module list of layers
                          # in this case the decoder has added a cross-attention that the encoder does not have
                          # thus skip this step and subtract one layer pos from encoder
                          encoder_layer_pos -= 1
                          continue
                  elif name not in encoder_modules:
                      continue
                  # NOTE(review): this depth guard is part of the elif chain, so it is not evaluated for numeric
                  # (digit) child names; a cycle through such children is only stopped by Python's recursion limit.
                  elif depth > 500:
                      raise ValueError(
                          "Max depth of recursive function `tie_encoder_to_decoder` reached. It seems that there is"
                          " a circular dependency between two or more `nn.Modules` of your model."
                      )
                  else:
                      decoder_name = encoder_name = name
                  tie_encoder_to_decoder_recursively(
                      decoder_modules[decoder_name],
                      encoder_modules[encoder_name],
                      module_name + "/" + name,
                      base_encoder_name,
                      uninitialized_encoder_weights,
                      depth=depth + 1,
                      total_encoder_name=f"{total_encoder_name}.{encoder_name}",
                      total_decoder_name=f"{total_decoder_name}.{decoder_name}",
                  )
                  all_encoder_weights.remove(module_name + "/" + encoder_name)

              # `+=` extends the list object in place, so the list shared with the caller receives these entries
              uninitialized_encoder_weights += list(all_encoder_weights)

      # tie weights recursively
      tie_encoder_to_decoder_recursively(
          decoder, encoder, base_model_prefix, base_encoder_name, uninitialized_encoder_weights
      )

      if len(uninitialized_encoder_weights) > 0:
          logger.warning(
              f"The following encoder weights were not tied to the decoder {uninitialized_encoder_weights}"
          )

      return tied_weights
  2725. def _tie_or_clone_weights(self, output_embeddings, input_embeddings):
  2726. """Tie or clone module weights depending of whether we are using TorchScript or not"""
  2727. if self.config.torchscript:
  2728. output_embeddings.weight = nn.Parameter(input_embeddings.weight.clone())
  2729. else:
  2730. output_embeddings.weight = input_embeddings.weight
  2731. # Passing hooks over to the embeddings if needed
  2732. # (currently limited to tensor parallel hooks and flags only)
  2733. if hasattr(input_embeddings, "_is_hooked") and getattr(input_embeddings, "_hf_tp_plan", None):
  2734. output_embeddings._is_hooked = input_embeddings._is_hooked
  2735. output_embeddings._hf_tp_plan = input_embeddings._hf_tp_plan
  2736. output_embeddings._forward_hooks = input_embeddings._forward_hooks
  2737. output_embeddings._forward_pre_hooks = input_embeddings._forward_pre_hooks
  2738. output_embeddings.__repr__ = (
  2739. lambda: f"{output_embeddings.__repr__()}\nTP Plan: {output_embeddings._hf_tp_plan}"
  2740. )
  2741. if getattr(output_embeddings, "bias", None) is not None:
  2742. output_embeddings.bias.data = nn.functional.pad(
  2743. output_embeddings.bias.data,
  2744. (
  2745. 0,
  2746. output_embeddings.weight.shape[0] - output_embeddings.bias.shape[0],
  2747. ),
  2748. "constant",
  2749. 0,
  2750. )
  2751. if hasattr(output_embeddings, "out_features") and hasattr(input_embeddings, "num_embeddings"):
  2752. output_embeddings.out_features = input_embeddings.num_embeddings
  2753. def _get_no_split_modules(self, device_map: str):
  2754. """
  2755. Get the modules of the model that should not be spit when using device_map. We iterate through the modules to
  2756. get the underlying `_no_split_modules`.
  2757. Args:
  2758. device_map (`str`):
  2759. The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
  2760. Returns:
  2761. `list[str]`: List of modules that should not be split
  2762. """
  2763. _no_split_modules = set()
  2764. modules_to_check = [self]
  2765. while len(modules_to_check) > 0:
  2766. module = modules_to_check.pop(-1)
  2767. # if the module does not appear in _no_split_modules, we also check the children
  2768. if module.__class__.__name__ not in _no_split_modules:
  2769. if isinstance(module, PreTrainedModel):
  2770. if module._no_split_modules is None:
  2771. raise ValueError(
  2772. f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
  2773. "class needs to implement the `_no_split_modules` attribute."
  2774. )
  2775. else:
  2776. _no_split_modules = _no_split_modules | set(module._no_split_modules)
  2777. modules_to_check += list(module.children())
  2778. return list(_no_split_modules)
  2779. def resize_token_embeddings(
  2780. self,
  2781. new_num_tokens: Optional[int] = None,
  2782. pad_to_multiple_of: Optional[int] = None,
  2783. mean_resizing: bool = True,
  2784. ) -> nn.Embedding:
  2785. """
  2786. Resizes input token embeddings matrix of the model if `new_num_tokens != config.vocab_size`.
  2787. Takes care of tying weights embeddings afterwards if the model class has a `tie_weights()` method.
  2788. Arguments:
  2789. new_num_tokens (`int`, *optional*):
  2790. The new number of tokens in the embedding matrix. Increasing the size will add newly initialized
  2791. vectors at the end. Reducing the size will remove vectors from the end. If not provided or `None`, just
  2792. returns a pointer to the input tokens `torch.nn.Embedding` module of the model without doing anything.
  2793. pad_to_multiple_of (`int`, *optional*):
  2794. If set will pad the embedding matrix to a multiple of the provided value.If `new_num_tokens` is set to
  2795. `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
  2796. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
  2797. `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
  2798. details about this, or help on choosing the correct value for resizing, refer to this guide:
  2799. https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
  2800. mean_resizing (`bool`):
  2801. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  2802. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  2803. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  2804. where the generated tokens' probabilities won't be affected by the added embeddings because initializing the new embeddings with the
  2805. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  2806. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2807. Return:
  2808. `torch.nn.Embedding`: Pointer to the input tokens Embeddings Module of the model.
  2809. """
  2810. model_embeds = self._resize_token_embeddings(new_num_tokens, pad_to_multiple_of, mean_resizing)
  2811. if new_num_tokens is None and pad_to_multiple_of is None:
  2812. return model_embeds
  2813. # Since we are basically reusing the same old embeddings with new weight values, gathering is required
  2814. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  2815. if is_deepspeed_zero3_enabled() and not is_quantized:
  2816. import deepspeed
  2817. with deepspeed.zero.GatheredParameters(model_embeds.weight, modifier_rank=None):
  2818. vocab_size = model_embeds.weight.shape[0]
  2819. else:
  2820. vocab_size = model_embeds.weight.shape[0]
  2821. # Update base model and current model config.
  2822. self.config.get_text_config().vocab_size = vocab_size
  2823. self.vocab_size = vocab_size
  2824. # Tie weights again if needed
  2825. self.tie_weights()
  2826. return model_embeds
  2827. def _resize_token_embeddings(self, new_num_tokens, pad_to_multiple_of=None, mean_resizing=True):
  2828. old_embeddings = self.get_input_embeddings()
  2829. new_embeddings = self._get_resized_embeddings(
  2830. old_embeddings, new_num_tokens, pad_to_multiple_of, mean_resizing
  2831. )
  2832. if hasattr(old_embeddings, "_hf_hook"):
  2833. hook = old_embeddings._hf_hook
  2834. add_hook_to_module(new_embeddings, hook)
  2835. old_embeddings_requires_grad = old_embeddings.weight.requires_grad
  2836. new_embeddings.requires_grad_(old_embeddings_requires_grad)
  2837. self.set_input_embeddings(new_embeddings)
  2838. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  2839. # Update new_num_tokens with the actual size of new_embeddings
  2840. if pad_to_multiple_of is not None:
  2841. if is_deepspeed_zero3_enabled() and not is_quantized:
  2842. import deepspeed
  2843. with deepspeed.zero.GatheredParameters(new_embeddings.weight, modifier_rank=None):
  2844. new_num_tokens = new_embeddings.weight.shape[0]
  2845. else:
  2846. new_num_tokens = new_embeddings.weight.shape[0]
  2847. # if word embeddings are not tied, make sure that lm head is resized as well
  2848. if (
  2849. self.get_output_embeddings() is not None
  2850. and not self.config.get_text_config(decoder=True).tie_word_embeddings
  2851. ):
  2852. old_lm_head = self.get_output_embeddings()
  2853. if isinstance(old_lm_head, torch.nn.Embedding):
  2854. new_lm_head = self._get_resized_embeddings(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
  2855. else:
  2856. new_lm_head = self._get_resized_lm_head(old_lm_head, new_num_tokens, mean_resizing=mean_resizing)
  2857. if hasattr(old_lm_head, "_hf_hook"):
  2858. hook = old_lm_head._hf_hook
  2859. add_hook_to_module(new_lm_head, hook)
  2860. old_lm_head_requires_grad = old_lm_head.weight.requires_grad
  2861. new_lm_head.requires_grad_(old_lm_head_requires_grad)
  2862. self.set_output_embeddings(new_lm_head)
  2863. return self.get_input_embeddings()
  2864. def _get_resized_embeddings(
  2865. self,
  2866. old_embeddings: nn.Embedding,
  2867. new_num_tokens: Optional[int] = None,
  2868. pad_to_multiple_of: Optional[int] = None,
  2869. mean_resizing: bool = True,
  2870. ) -> nn.Embedding:
  2871. """
  2872. Build a resized Embedding Module from a provided token Embedding Module. Increasing the size will add newly
  2873. initialized vectors at the end. Reducing the size will remove vectors from the end
  2874. Args:
  2875. old_embeddings (`torch.nn.Embedding`):
  2876. Old embeddings to be resized.
  2877. new_num_tokens (`int`, *optional*):
  2878. New number of tokens in the embedding matrix.
  2879. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  2880. vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
  2881. `torch.nn.Embedding` module of the model without doing anything.
  2882. pad_to_multiple_of (`int`, *optional*):
  2883. If set will pad the embedding matrix to a multiple of the provided value. If `new_num_tokens` is set to
  2884. `None` will just pad the embedding to a multiple of `pad_to_multiple_of`.
  2885. This is especially useful to enable the use of Tensor Cores on NVIDIA hardware with compute capability
  2886. `>= 7.5` (Volta), or on TPUs which benefit from having sequence lengths be a multiple of 128. For more
  2887. details about this, or help on choosing the correct value for resizing, refer to this guide:
  2888. https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc
  2889. mean_resizing (`bool`):
  2890. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  2891. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  2892. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  2893. where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the
  2894. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  2895. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2896. Return:
  2897. `torch.nn.Embedding`: Pointer to the resized Embedding Module or the old Embedding Module if
  2898. `new_num_tokens` is `None`
  2899. """
  2900. if pad_to_multiple_of is not None:
  2901. if not isinstance(pad_to_multiple_of, int):
  2902. raise ValueError(
  2903. f"Asking to pad the embedding matrix to a multiple of `{pad_to_multiple_of}`, which is not and integer. Please make sure to pass an integer"
  2904. )
  2905. if new_num_tokens is None:
  2906. new_num_tokens = old_embeddings.weight.shape[0]
  2907. new_num_tokens = ((new_num_tokens + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
  2908. else:
  2909. logger.info(
  2910. "You are resizing the embedding layer without providing a `pad_to_multiple_of` parameter. This means that the new embedding"
  2911. f" dimension will be {new_num_tokens}. This might induce some performance reduction as *Tensor Cores* will not be available."
  2912. " For more details about this, or help on choosing the correct value for resizing, refer to this guide:"
  2913. " https://docs.nvidia.com/deeplearning/performance/dl-performance-matrix-multiplication/index.html#requirements-tc"
  2914. )
  2915. if new_num_tokens is None:
  2916. return old_embeddings
  2917. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  2918. if is_deepspeed_zero3_enabled() and not is_quantized:
  2919. import deepspeed
  2920. with deepspeed.zero.GatheredParameters(old_embeddings.weight, modifier_rank=None):
  2921. old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
  2922. else:
  2923. old_num_tokens, old_embedding_dim = old_embeddings.weight.size()
  2924. if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
  2925. return old_embeddings
  2926. if not isinstance(old_embeddings, nn.Embedding):
  2927. raise TypeError(
  2928. f"Old embeddings are of type {type(old_embeddings)}, which is not an instance of {nn.Embedding}. You"
  2929. " should either use a different resize function or make sure that `old_embeddings` are an instance of"
  2930. f" {nn.Embedding}."
  2931. )
  2932. # Build new embeddings
  2933. # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
  2934. # because the shape of the new embedding layer is used across various modeling files
  2935. # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
  2936. # to errors when training.
  2937. new_embeddings = nn.Embedding(
  2938. new_num_tokens,
  2939. old_embedding_dim,
  2940. device=old_embeddings.weight.device,
  2941. dtype=old_embeddings.weight.dtype,
  2942. )
  2943. if new_num_tokens > old_num_tokens and not mean_resizing:
  2944. # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
  2945. self._init_weights(new_embeddings)
  2946. elif new_num_tokens > old_num_tokens and mean_resizing:
  2947. # initialize new embeddings (in particular added tokens). The new embeddings will be initialized
  2948. # from a multivariate normal distribution that has old embeddings' mean and covariance.
  2949. # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  2950. logger.warning_once(
  2951. "The new embeddings will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
  2952. "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
  2953. "To disable this, use `mean_resizing=False`"
  2954. )
  2955. added_num_tokens = new_num_tokens - old_num_tokens
  2956. if is_deepspeed_zero3_enabled() and not is_quantized:
  2957. import deepspeed
  2958. with deepspeed.zero.GatheredParameters([old_embeddings.weight], modifier_rank=None):
  2959. self._init_added_embeddings_weights_with_mean(
  2960. old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
  2961. )
  2962. else:
  2963. self._init_added_embeddings_weights_with_mean(
  2964. old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
  2965. )
  2966. # Copy token embeddings from the previous weights
  2967. # numbers of tokens to copy
  2968. n = min(old_num_tokens, new_num_tokens)
  2969. if is_deepspeed_zero3_enabled() and not is_quantized:
  2970. import deepspeed
  2971. params = [old_embeddings.weight, new_embeddings.weight]
  2972. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  2973. new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
  2974. else:
  2975. new_embeddings.weight.data[:n, :] = old_embeddings.weight.data[:n, :]
  2976. # Replace weights in old_embeddings and return to maintain the same embedding type.
  2977. # This ensures correct functionality when a Custom Embedding class is passed as input.
  2978. # The input and output embedding types remain consistent. (c.f. https://github.com/huggingface/transformers/pull/31979)
  2979. if is_deepspeed_zero3_enabled() and not is_quantized:
  2980. import deepspeed
  2981. params = [old_embeddings.weight, new_embeddings.weight]
  2982. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  2983. old_embeddings.weight = new_embeddings.weight
  2984. old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
  2985. # If the new number of tokens is smaller than the original `padding_idx`, the `padding_idx`
  2986. # will be set to `None` in the resized embeddings.
  2987. if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
  2988. old_embeddings.padding_idx = None
  2989. else:
  2990. old_embeddings.weight.data = new_embeddings.weight.data
  2991. old_embeddings.num_embeddings = new_embeddings.weight.data.shape[0]
  2992. if old_embeddings.padding_idx is not None and (new_num_tokens - 1) < old_embeddings.padding_idx:
  2993. old_embeddings.padding_idx = None
  2994. return old_embeddings
  2995. def _get_resized_lm_head(
  2996. self,
  2997. old_lm_head: nn.Linear,
  2998. new_num_tokens: Optional[int] = None,
  2999. transposed: bool = False,
  3000. mean_resizing: bool = True,
  3001. ) -> nn.Linear:
  3002. """
  3003. Build a resized Linear Module from a provided old Linear Module. Increasing the size will add newly initialized
  3004. vectors at the end. Reducing the size will remove vectors from the end
  3005. Args:
  3006. old_lm_head (`torch.nn.Linear`):
  3007. Old lm head liner layer to be resized.
  3008. new_num_tokens (`int`, *optional*):
  3009. New number of tokens in the linear matrix.
  3010. Increasing the size will add newly initialized vectors at the end. Reducing the size will remove
  3011. vectors from the end. If not provided or `None`, just returns a pointer to the input tokens
  3012. `torch.nn.Linear` module of the model without doing anything. transposed (`bool`, *optional*, defaults
  3013. to `False`): Whether `old_lm_head` is transposed or not. If True `old_lm_head.size()` is `lm_head_dim,
  3014. vocab_size` else `vocab_size, lm_head_dim`.
  3015. mean_resizing (`bool`):
  3016. Whether to initialize the added embeddings from a multivariate normal distribution that has old embeddings' mean and
  3017. covariance or to initialize them with a normal distribution that has a mean of zero and std equals `config.initializer_range`.
  3018. Setting `mean_resizing` to `True` is useful when increasing the size of the embeddings of causal language models,
  3019. where the generated tokens' probabilities will not be affected by the added embeddings because initializing the new embeddings with the
  3020. old embeddings' mean will reduce the kl-divergence between the next token probability before and after adding the new embeddings.
  3021. Refer to this article for more information: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  3022. Return:
  3023. `torch.nn.Linear`: Pointer to the resized Linear Module or the old Linear Module if `new_num_tokens` is
  3024. `None`
  3025. """
  3026. if new_num_tokens is None:
  3027. return old_lm_head
  3028. is_quantized = hasattr(self, "hf_quantizer") and self.hf_quantizer is not None
  3029. if is_deepspeed_zero3_enabled() and not is_quantized:
  3030. import deepspeed
  3031. with deepspeed.zero.GatheredParameters(old_lm_head.weight, modifier_rank=None):
  3032. old_num_tokens, old_lm_head_dim = (
  3033. old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
  3034. )
  3035. else:
  3036. old_num_tokens, old_lm_head_dim = (
  3037. old_lm_head.weight.size() if not transposed else old_lm_head.weight.t().size()
  3038. )
  3039. if old_num_tokens == new_num_tokens and not is_deepspeed_zero3_enabled():
  3040. return old_lm_head
  3041. if not isinstance(old_lm_head, nn.Linear):
  3042. raise TypeError(
  3043. f"Old language model head is of type {type(old_lm_head)}, which is not an instance of {nn.Linear}. You"
  3044. " should either use a different resize function or make sure that `old_lm_head` are an instance of"
  3045. f" {nn.Linear}."
  3046. )
  3047. # Build new lm head
  3048. new_lm_head_shape = (old_lm_head_dim, new_num_tokens) if not transposed else (new_num_tokens, old_lm_head_dim)
  3049. has_new_lm_head_bias = old_lm_head.bias is not None
  3050. # When using DeepSpeed ZeRO-3, we shouldn't create new embeddings with DeepSpeed init
  3051. # because the shape of the new embedding layer is used across various modeling files
  3052. # as well as to update config vocab size. Shape will be 0 when using DeepSpeed init leading
  3053. # to errors when training.
  3054. new_lm_head = nn.Linear(
  3055. *new_lm_head_shape,
  3056. bias=has_new_lm_head_bias,
  3057. device=old_lm_head.weight.device,
  3058. dtype=old_lm_head.weight.dtype,
  3059. )
  3060. if new_num_tokens > old_num_tokens and not mean_resizing:
  3061. # initialize new embeddings (in particular added tokens) with a mean of 0 and std equals `config.initializer_range`.
  3062. self._init_weights(new_lm_head)
  3063. elif new_num_tokens > old_num_tokens and mean_resizing:
  3064. # initialize new lm_head weights (in particular added tokens). The new lm_head weights
  3065. # will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance.
  3066. # as described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html
  3067. logger.warning_once(
  3068. "The new lm_head weights will be initialized from a multivariate normal distribution that has old embeddings' mean and covariance. "
  3069. "As described in this article: https://nlp.stanford.edu/~johnhew/vocab-expansion.html. "
  3070. "To disable this, use `mean_resizing=False`"
  3071. )
  3072. added_num_tokens = new_num_tokens - old_num_tokens
  3073. if is_deepspeed_zero3_enabled() and not is_quantized:
  3074. import deepspeed
  3075. params = [old_lm_head.weight]
  3076. if has_new_lm_head_bias:
  3077. params += [old_lm_head.bias]
  3078. with deepspeed.zero.GatheredParameters(params, modifier_rank=None):
  3079. self._init_added_lm_head_weights_with_mean(
  3080. old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
  3081. )
  3082. if has_new_lm_head_bias:
  3083. self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
  3084. else:
  3085. self._init_added_lm_head_weights_with_mean(
  3086. old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens, transposed
  3087. )
  3088. if has_new_lm_head_bias:
  3089. self._init_added_lm_head_bias_with_mean(old_lm_head, new_lm_head, added_num_tokens)
  3090. num_tokens_to_copy = min(old_num_tokens, new_num_tokens)
  3091. if is_deepspeed_zero3_enabled() and not is_quantized:
  3092. import deepspeed
  3093. params = [old_lm_head.weight, old_lm_head.bias, new_lm_head.weight, new_lm_head.bias]
  3094. with deepspeed.zero.GatheredParameters(params, modifier_rank=0):
  3095. self._copy_lm_head_original_to_resized(
  3096. new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  3097. )
  3098. else:
  3099. self._copy_lm_head_original_to_resized(
  3100. new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  3101. )
  3102. return new_lm_head
  3103. def _init_added_embeddings_weights_with_mean(
  3104. self, old_embeddings, new_embeddings, old_embedding_dim, old_num_tokens, added_num_tokens
  3105. ):
  3106. old_embeddings_weight = old_embeddings.weight.data.to(torch.float32)
  3107. mean_embeddings = torch.mean(old_embeddings_weight, axis=0)
  3108. old_centered_embeddings = old_embeddings_weight - mean_embeddings
  3109. covariance = old_centered_embeddings.T @ old_centered_embeddings / old_num_tokens
  3110. # Check if the covariance is positive definite.
  3111. epsilon = 1e-9
  3112. is_covariance_psd = constraints.positive_definite.check(epsilon * covariance).all()
  3113. if is_covariance_psd:
  3114. # If covariances is positive definite, a distribution can be created. and we can sample new weights from it.
  3115. distribution = torch.distributions.multivariate_normal.MultivariateNormal(
  3116. mean_embeddings, covariance_matrix=epsilon * covariance
  3117. )
  3118. new_embeddings.weight.data[-1 * added_num_tokens :, :] = distribution.sample(
  3119. sample_shape=(added_num_tokens,)
  3120. ).to(old_embeddings.weight.dtype)
  3121. else:
  3122. # Otherwise, just initialize with the mean. because distribution will not be created.
  3123. new_embeddings.weight.data[-1 * added_num_tokens :, :] = (
  3124. mean_embeddings[None, :].repeat(added_num_tokens, 1).to(old_embeddings.weight.dtype)
  3125. )
  3126. def _init_added_lm_head_weights_with_mean(
  3127. self,
  3128. old_lm_head,
  3129. new_lm_head,
  3130. old_lm_head_dim,
  3131. old_num_tokens,
  3132. added_num_tokens,
  3133. transposed: bool = False,
  3134. ):
  3135. if transposed:
  3136. # Transpose to the desired shape for the function.
  3137. new_lm_head.weight.data = new_lm_head.weight.data.T
  3138. old_lm_head.weight.data = old_lm_head.weight.data.T
  3139. # The same initialization logic as Embeddings.
  3140. self._init_added_embeddings_weights_with_mean(
  3141. old_lm_head, new_lm_head, old_lm_head_dim, old_num_tokens, added_num_tokens
  3142. )
  3143. if transposed:
  3144. # Transpose again to the correct shape.
  3145. new_lm_head.weight.data = new_lm_head.weight.data.T
  3146. old_lm_head.weight.data = old_lm_head.weight.data.T
  3147. def _init_added_lm_head_bias_with_mean(self, old_lm_head, new_lm_head, added_num_tokens):
  3148. bias_mean = torch.mean(old_lm_head.bias.data, axis=0, dtype=torch.float32)
  3149. bias_std = torch.std(old_lm_head.bias.data, axis=0).to(torch.float32)
  3150. new_lm_head.bias.data[-1 * added_num_tokens :].normal_(mean=bias_mean, std=1e-9 * bias_std)
  3151. def _copy_lm_head_original_to_resized(
  3152. self, new_lm_head, old_lm_head, num_tokens_to_copy, transposed, has_new_lm_head_bias
  3153. ):
  3154. # Copy old lm head weights to new lm head
  3155. if not transposed:
  3156. new_lm_head.weight.data[:num_tokens_to_copy, :] = old_lm_head.weight.data[:num_tokens_to_copy, :]
  3157. else:
  3158. new_lm_head.weight.data[:, :num_tokens_to_copy] = old_lm_head.weight.data[:, :num_tokens_to_copy]
  3159. # Copy bias weights to new lm head
  3160. if has_new_lm_head_bias:
  3161. new_lm_head.bias.data[:num_tokens_to_copy] = old_lm_head.bias.data[:num_tokens_to_copy]
  3162. def resize_position_embeddings(self, new_num_position_embeddings: int):
  3163. raise NotImplementedError(
  3164. f"`resize_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
  3165. f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
  3166. )
  3167. def get_position_embeddings(self) -> Union[nn.Embedding, tuple[nn.Embedding]]:
  3168. raise NotImplementedError(
  3169. f"`get_position_embeddings` is not implemented for {self.__class__}`. To implement it, you should "
  3170. f"overwrite this method in the class {self.__class__} in `modeling_{self.__class__.__module__}.py`"
  3171. )
  3172. def init_weights(self):
  3173. """
  3174. If needed prunes and maybe initializes weights. If using a custom `PreTrainedModel`, you need to implement any
  3175. initialization logic in `_init_weights`.
  3176. """
  3177. # Prune heads if needed
  3178. if self.config.pruned_heads:
  3179. self.prune_heads(self.config.pruned_heads)
  3180. if _init_weights:
  3181. # Initialize weights
  3182. self.initialize_weights()
  3183. # Tie weights should be skipped when not initializing all weights
  3184. # since from_pretrained(...) calls tie weights anyways
  3185. self.tie_weights()
  3186. def prune_heads(self, heads_to_prune: dict[int, list[int]]):
  3187. """
  3188. Prunes heads of the base model.
  3189. Arguments:
  3190. heads_to_prune (`dict[int, list[int]]`):
  3191. Dictionary with keys being selected layer indices (`int`) and associated values being the list of heads
  3192. to prune in said layer (list of `int`). For instance {1: [0, 2], 2: [2, 3]} will prune heads 0 and 2 on
  3193. layer 1 and heads 2 and 3 on layer 2.
  3194. """
  3195. # save new sets of pruned heads as union of previously stored pruned heads and newly pruned heads
  3196. for layer, heads in heads_to_prune.items():
  3197. union_heads = set(self.config.pruned_heads.get(layer, [])) | set(heads)
  3198. self.config.pruned_heads[layer] = list(union_heads) # Unfortunately we have to store it as list for JSON
  3199. self.base_model._prune_heads(heads_to_prune)
  3200. def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
  3201. """
  3202. Activates gradient checkpointing for the current model.
  3203. Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
  3204. activations".
  3205. We pass the `__call__` method of the modules instead of `forward` because `__call__` attaches all the hooks of
  3206. the module. https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
  3207. Args:
  3208. gradient_checkpointing_kwargs (dict, *optional*):
  3209. Additional keyword arguments passed along to the `torch.utils.checkpoint.checkpoint` function.
  3210. """
  3211. if not self.supports_gradient_checkpointing:
  3212. raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
  3213. if gradient_checkpointing_kwargs is None:
  3214. gradient_checkpointing_kwargs = {"use_reentrant": True}
  3215. gradient_checkpointing_func = functools.partial(checkpoint, **gradient_checkpointing_kwargs)
  3216. # For old GC format (transformers < 4.35.0) for models that live on the Hub
  3217. # we will fall back to the overwritten `_set_gradient_checkpointing` method
  3218. _is_using_old_format = "value" in inspect.signature(self._set_gradient_checkpointing).parameters
  3219. if not _is_using_old_format:
  3220. self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)
  3221. else:
  3222. self.apply(partial(self._set_gradient_checkpointing, value=True))
  3223. logger.warning(
  3224. "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
  3225. "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
  3226. )
  3227. if getattr(self, "_hf_peft_config_loaded", False):
  3228. # When using PEFT + gradient checkpointing + Trainer we need to make sure the input has requires_grad=True
  3229. # we do it also on PEFT: https://github.com/huggingface/peft/blob/85013987aa82aa1af3da1236b6902556ce3e483e/src/peft/peft_model.py#L334
  3230. # When training with PEFT, only LoRA layers will have requires grad set to True, but the output of frozen layers need to propagate
  3231. # the gradients to make sure the gradient flows.
  3232. self.enable_input_require_grads()
  def _set_gradient_checkpointing(self, enable: bool = True, gradient_checkpointing_func: Callable = checkpoint):
      """
      Toggle gradient checkpointing on the model and on every submodule that supports it.

      A module "supports" gradient checkpointing when it has a `gradient_checkpointing` attribute. For every such
      module the flag is set to `enable` and `gradient_checkpointing_func` is stored as
      `_gradient_checkpointing_func`.

      Args:
          enable (`bool`, *optional*, defaults to `True`):
              Value written to the `gradient_checkpointing` attribute of every supporting module.
          gradient_checkpointing_func (`Callable`, *optional*, defaults to `checkpoint`):
              Function stored on every supporting module as `_gradient_checkpointing_func`.

      Raises:
          `ValueError`: If neither the model nor any of its submodules has a `gradient_checkpointing` attribute.
      """
      # The top-level model is listed first on purpose: some architectures (e.g. LongT5Stack) inherit from
      # `PreTrainedModel` and carry the flag themselves. `self.modules()` yields `self` again, which is harmless.
      candidates = (self, *self.modules())
      supporting = [module for module in candidates if hasattr(module, "gradient_checkpointing")]
      for module in supporting:
          module._gradient_checkpointing_func = gradient_checkpointing_func
          module.gradient_checkpointing = enable

      if not supporting:
          raise ValueError(
              f"{self.__class__.__name__} is not compatible with gradient checkpointing. Make sure all the architecture support it by setting a boolean attribute"
              " `gradient_checkpointing` to modules of the model that uses checkpointing."
          )
  def gradient_checkpointing_disable(self):
      """
      Deactivates gradient checkpointing for the current model.

      Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
      activations".

      When the model supports gradient checkpointing, its `_set_gradient_checkpointing` hook is called with
      `enable=False`. Legacy overrides (transformers < 4.35.0) that take a `value` parameter are instead applied to
      every submodule through `self.apply`, and a deprecation warning is logged. If PEFT adapters are loaded, the
      input `requires_grad` hook installed for them is removed as well.
      """
      if self.supports_gradient_checkpointing:
          setter = self._set_gradient_checkpointing
          # Models living on the Hub that were written for transformers < 4.35.0 override
          # `_set_gradient_checkpointing` with a `value` parameter; keep driving those the old way.
          legacy_signature = "value" in inspect.signature(setter).parameters
          if legacy_signature:
              logger.warning(
                  "You are using an old version of the checkpointing format that is deprecated (We will also silently ignore `gradient_checkpointing_kwargs` in case you passed it)."
                  "Please update to the new format on your modeling file. To use the new format, you need to completely remove the definition of the method `_set_gradient_checkpointing` in your model."
              )
              self.apply(partial(setter, value=False))
          else:
              setter(enable=False)

      if getattr(self, "_hf_peft_config_loaded", False):
          self.disable_input_require_grads()
  @property
  def is_gradient_checkpointing(self) -> bool:
      """
      Whether gradient checkpointing is activated for this model or not.

      Returns `True` as soon as the model itself or any of its submodules has a truthy `gradient_checkpointing`
      attribute, `False` otherwise. Note that in other frameworks this feature can be referred to as "activation
      checkpointing" or "checkpoint activations".
      """
      for module in self.modules():
          if hasattr(module, "gradient_checkpointing") and module.gradient_checkpointing:
              return True
      return False
  def save_pretrained(
      self,
      save_directory: Union[str, os.PathLike],
      is_main_process: bool = True,
      state_dict: Optional[dict] = None,
      save_function: Callable = torch.save,
      push_to_hub: bool = False,
      max_shard_size: Union[int, str] = "5GB",
      safe_serialization: bool = True,
      variant: Optional[str] = None,
      token: Optional[Union[str, bool]] = None,
      save_peft_format: bool = True,
      **kwargs,
  ):
      """
      Save a model and its configuration file to a directory, so that it can be re-loaded using the
      [`~PreTrainedModel.from_pretrained`] class method.

      Arguments:
          save_directory (`str` or `os.PathLike`):
              Directory to which to save. Will be created if it doesn't exist.
          is_main_process (`bool`, *optional*, defaults to `True`):
              Whether the process calling this is the main process or not. Useful when in distributed training like
              TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
              the main process to avoid race conditions.
          state_dict (nested dictionary of `torch.Tensor`):
              The state dictionary of the model to save. Will default to `self.state_dict()`, but can be used to only
              save parts of the model or if special precautions need to be taken when recovering the state dictionary
              of a model (like when using model parallelism).
          save_function (`Callable`):
              The function to use to save the state dictionary. Useful on distributed training like TPUs when one
              needs to replace `torch.save` by another method. Only used when `safe_serialization=False`.
          push_to_hub (`bool`, *optional*, defaults to `False`):
              Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
              repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
              namespace).
          max_shard_size (`int` or `str`, *optional*, defaults to `"5GB"`):
              The maximum size for a checkpoint before being sharded. Each checkpoint shard will then be smaller
              than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
              We default it to 5GB in order for models to be able to run easily on free-tier google colab instances
              without CPU OOM issues.

              <Tip warning={true}>

              If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
              which will be bigger than `max_shard_size`.

              </Tip>

          safe_serialization (`bool`, *optional*, defaults to `True`):
              Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
          variant (`str`, *optional*):
              If specified, the variant is inserted into the (non-adapter) weights file name, e.g.
              `pytorch_model.<variant>.bin`.
          token (`str` or `bool`, *optional*):
              The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
              the token generated when running `hf auth login` (stored in `~/.huggingface`).
          save_peft_format (`bool`, *optional*, defaults to `True`):
              For backward compatibility with PEFT library, in case adapter weights are attached to the model, all
              keys of the state dict of adapters need to be prepended with `base_model.model`. Advanced users can
              disable this behaviour by setting `save_peft_format` to `False`.
          kwargs (`dict[str, Any]`, *optional*):
              Additional key word arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.

      Raises:
          `ValueError`: If both `token` and the deprecated `use_auth_token` are given, if the model is quantized with
              a non-serializable method, or if several PEFT adapters are active at once.
          `ImportError`: If tensor parallelism is used with `huggingface_hub` < 0.31.4, or if offloaded parameters
              must be saved with accelerate < 0.31.
          `RuntimeError`: If `safe_serialization=True` and the state dict contains shared tensors that cannot be
              resolved.
      """
      use_auth_token = kwargs.pop("use_auth_token", None)
      ignore_metadata_errors = kwargs.pop("ignore_metadata_errors", False)

      if use_auth_token is not None:
          warnings.warn(
              "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
              FutureWarning,
          )
          if token is not None:
              raise ValueError(
                  "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
              )
          token = use_auth_token

      if token is not None:
          kwargs["token"] = token

      _hf_peft_config_loaded = getattr(self, "_hf_peft_config_loaded", False)

      hf_quantizer = getattr(self, "hf_quantizer", None)
      quantization_serializable = (
          hf_quantizer is not None
          and isinstance(hf_quantizer, HfQuantizer)
          and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
      )

      if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
          raise ValueError(
              f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
              " the logger on the traceback to understand the reason why the quantized model is not serializable."
          )

      if "save_config" in kwargs:
          warnings.warn(
              "`save_config` is deprecated and will be removed in v5 of Transformers. Use `is_main_process` instead."
          )
          is_main_process = kwargs.pop("save_config")

      # we need to check against tp_size, not tp_plan, as tp_plan is substituted to the class one
      if self._tp_size is not None and not is_huggingface_hub_greater_or_equal("0.31.4"):
          raise ImportError(
              "Saving a model with tensor parallelism requires `huggingface_hub` version 0.31.4 or higher."
          )

      if os.path.isfile(save_directory):
          logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
          return

      os.makedirs(save_directory, exist_ok=True)

      if push_to_hub:
          commit_message = kwargs.pop("commit_message", None)
          # `save_directory` may be an `os.PathLike` (e.g. `pathlib.Path`, which has no `.split`) and may end with a
          # separator, so normalise it before taking its last component as the default repo name.
          default_repo_id = os.path.basename(os.path.normpath(os.fspath(save_directory)))
          repo_id = kwargs.pop("repo_id", default_repo_id)
          create_pr = kwargs.pop("create_pr", False)
          repo_id = self._create_repo(repo_id, **kwargs)
          files_timestamps = self._get_files_timestamps(save_directory)

      metadata = {}
      if hf_quantizer is not None:
          state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
      metadata["format"] = "pt"

      # Only save the model itself if we are using distributed training
      model_to_save = unwrap_model(self)

      # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
      # we currently don't use this setting automatically, but may start to use with v5
      dtype = get_parameter_dtype(model_to_save)
      model_to_save.config.dtype = str(dtype).split(".")[1]

      # Attach architecture to the config
      # When using FSDP2, unwrapping is a noop, so the model name doesn't change back to the original model name
      model_to_save.config.architectures = [model_to_save.__class__.__name__.removeprefix("FSDP")]

      # If we have a custom model, we copy the file defining it in the folder and set the attributes so it can be
      # loaded from the Hub.
      if self._auto_class is not None:
          custom_object_save(self, save_directory, config=self.config)

      # Save the config
      if is_main_process:
          if not _hf_peft_config_loaded:
              # If the model config has set attributes that should be in the generation config, move them there.
              misplaced_generation_parameters = model_to_save.config._get_non_default_generation_parameters()
              if self.can_generate() and len(misplaced_generation_parameters) > 0:
                  warnings.warn(
                      "Moving the following attributes in the config to the generation config: "
                      f"{misplaced_generation_parameters}. You are seeing this warning because you've set "
                      "generation parameters in the model config, as opposed to in the generation config.",
                      UserWarning,
                  )
                  for param_name, param_value in misplaced_generation_parameters.items():
                      setattr(model_to_save.generation_config, param_name, param_value)
                      setattr(model_to_save.config, param_name, None)

              model_to_save.config.save_pretrained(save_directory)
          if self.can_generate():
              model_to_save.generation_config.save_pretrained(save_directory)

          if _hf_peft_config_loaded:
              logger.info(
                  "Detected adapters on the model, saving the model in the PEFT format, only adapter weights will be saved."
              )
              state_dict = model_to_save.get_adapter_state_dict(state_dict=state_dict)

              if save_peft_format:
                  logger.info(
                      "To match the expected format of the PEFT library, all keys of the state dict of adapters will be prepended with `base_model.model`."
                  )
                  peft_state_dict = {}
                  for key, value in state_dict.items():
                      peft_state_dict[f"base_model.model.{key}"] = value
                  state_dict = peft_state_dict

              active_adapter = self.active_adapters()

              if len(active_adapter) > 1:
                  raise ValueError(
                      "Multiple active adapters detected, saving multiple active adapters is not supported yet. You can save adapters separately one by one "
                      "by iteratively calling `model.set_adapter(adapter_name)` then `model.save_pretrained(...)`"
                  )
              active_adapter = active_adapter[0]

              current_peft_config = self.peft_config[active_adapter]
              current_peft_config.save_pretrained(save_directory)

      # for offloaded modules
      module_map = {}

      # Save the model
      if state_dict is None:
          # if any model parameters are offloaded, make module map
          if (
              hasattr(self, "hf_device_map")
              and len(set(self.hf_device_map.values())) > 1
              and ("cpu" in self.hf_device_map.values() or "disk" in self.hf_device_map.values())
          ):
              warnings.warn(
                  "Attempting to save a model with offloaded modules. Ensure that unallocated cpu memory exceeds the `shard_size` (5GB default)"
              )
              for name, module in model_to_save.named_modules():
                  if name == "":
                      continue
                  module_state_dict = module.state_dict()

                  for key in module_state_dict:
                      module_map[name + f".{key}"] = module
          state_dict = model_to_save.state_dict()

      # Saving offloaded parameters requires accelerate >= 0.31. Check it once, before any stale shard is removed
      # or any new shard is written, so that a failure cannot leave a half-written checkpoint behind.
      if module_map and accelerate_version < version.parse("0.31"):
          raise ImportError(
              f"You need accelerate version to be greater or equal than 0.31 to save models with offloaded parameters. Detected version {accelerate_version}. "
              f"Please upgrade accelerate with `pip install -U accelerate`"
          )

      if any(
          allowed_name in class_name.__name__.lower()
          for class_name in self.__class__.__mro__[:-1]
          for allowed_name in VLMS
      ):
          reverse_key_mapping = {v: k for k, v in self._checkpoint_conversion_mapping.items()}

          original_state_dict = {}
          for key, value in state_dict.items():
              for pattern, replacement in reverse_key_mapping.items():
                  replacement = replacement.lstrip("^")  # strip off un-needed chars and patterns
                  replacement = re.sub(r"\(.*\)", "", replacement)
                  key, n_replace = re.subn(pattern, replacement, key)
                  # Early exit of the loop
                  if n_replace > 0:
                      break
              original_state_dict[key] = value
          state_dict = original_state_dict

      # Translate state_dict from smp to hf if saving with smp >= 1.10
      if IS_SAGEMAKER_MP_POST_1_10:
          for smp_to_hf, _ in smp.state.module_manager.translate_functions:
              state_dict = smp_to_hf(state_dict)

      # Handle the case where some state_dict keys shouldn't be saved
      if self._keys_to_ignore_on_save is not None:
          for ignore_key in self._keys_to_ignore_on_save:
              if ignore_key in state_dict:
                  del state_dict[ignore_key]

      # Rename state_dict keys before saving to file. Do nothing unless overridden in a particular model.
      # (initially introduced with TimmWrapperModel to remove prefix and make checkpoints compatible with timm)
      state_dict = self._fix_state_dict_keys_on_save(state_dict)
      # If model was sharded, we cannot properly determine sizes of tensors that `local_*` strategy was used,
      # therefore we replace them with DTensors that are equivalently sharded
      if self._tp_size is not None:
          state_dict = replace_state_dict_local_with_dtensor(state_dict, self._tp_plan, self._device_mesh)

      if safe_serialization:
          # TODO: fix safe_serialization for tied weights
          # Safetensors does not allow tensor aliasing.
          # We're going to remove aliases before saving
          ptrs = collections.defaultdict(list)
          for name, tensor in state_dict.items():
              if not isinstance(tensor, torch.Tensor):
                  # Sometimes in the state_dict we have non-tensor objects.
                  # e.g. in bitsandbytes we have some `str` objects in the state_dict
                  # In the non-tensor case, fall back to the pointer of the object itself
                  ptrs[id(tensor)].append(name)

              elif tensor.device.type == "meta":
                  # In offloaded cases, there may be meta tensors in the state_dict.
                  # For these cases, key by the pointer of the original tensor object
                  # (state_dict tensors are detached and therefore no longer shared)
                  tensor = self.get_parameter(name)
                  ptrs[id(tensor)].append(name)

              else:
                  ptrs[id_tensor_storage(tensor)].append(name)

          shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}

          # Recursively descend to find tied weight keys
          _tied_weights_keys = _get_tied_weight_keys(self)
          error_names = []
          to_delete_names = set()
          for names in shared_ptrs.values():
              # Removing the keys which are declared as known duplicates on
              # load. This allows to make sure the name which is kept is consistent.
              if _tied_weights_keys is not None:
                  found = 0
                  for name in sorted(names):
                      matches_pattern = any(re.search(pat, name) for pat in _tied_weights_keys)
                      if matches_pattern and name in state_dict:
                          found += 1
                          if found < len(names):
                              to_delete_names.add(name)
          # We are entering a place where the weights and the transformers configuration do NOT match.
          shared_names, disjoint_names = _find_disjoint(shared_ptrs.values(), state_dict)
          # Those are actually tensor sharing but disjoint from each other, we can safely clone them
          # Reloaded won't have the same property, but it shouldn't matter in any meaningful way.
          for name in disjoint_names:
              state_dict[name] = state_dict[name].clone()

          # When not all duplicates have been cleaned, still remove those keys, but put a clear warning.
          # If the link between tensors was done at runtime then `from_pretrained` will not get
          # the key back leading to random tensor. A proper warning will be shown
          # during reload (if applicable), but since the file is not necessarily compatible with
          # the config, better show a proper warning.
          shared_names, identical_names = _find_identical(shared_names, state_dict)
          # delete tensors that have identical storage
          for inames in identical_names:
              known = inames.intersection(to_delete_names)
              for name in known:
                  del state_dict[name]
              unknown = inames.difference(to_delete_names)
              if len(unknown) > 1:
                  error_names.append(unknown)

          if shared_names:
              error_names.extend(shared_names)

          if len(error_names) > 0:
              raise RuntimeError(
                  f"The weights trying to be saved contained shared tensors {error_names} that are mismatching "
                  "the transformers base configuration. Try saving using `safe_serialization=False`, setting the "
                  "`_dynamic_tied_weights_keys` attribute for affected modules, or remove this tensor sharing.",
              )

      # Shard the model if it is too big.
      if not _hf_peft_config_loaded:
          weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
          weights_name = _add_variant(weights_name, variant)
      else:
          weights_name = ADAPTER_SAFE_WEIGHTS_NAME if safe_serialization else ADAPTER_WEIGHTS_NAME

      filename_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(".safetensors", "{suffix}.safetensors")
      state_dict_split = split_torch_state_dict_into_shards(
          state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
      )
      # Save index if sharded
      index = None
      if state_dict_split.is_sharded:
          index = {
              "metadata": {"total_parameters": self.num_parameters(), **state_dict_split.metadata},
              "weight_map": state_dict_split.tensor_to_filename,
          }

      # Clean the folder from a previous save
      for filename in os.listdir(save_directory):
          full_filename = os.path.join(save_directory, filename)
          # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
          # in distributed settings to avoid race conditions.
          weights_no_suffix = weights_name.replace(".bin", "").replace(".safetensors", "")

          # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
          filename_no_suffix = filename.replace(".bin", "").replace(".safetensors", "")
          reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")

          if (
              filename.startswith(weights_no_suffix)
              and os.path.isfile(full_filename)
              and filename not in state_dict_split.filename_to_tensors
              and is_main_process
              and reg.fullmatch(filename_no_suffix) is not None
          ):
              os.remove(full_filename)

      # Save the model
      filename_to_tensors = state_dict_split.filename_to_tensors.items()
      if module_map:
          filename_to_tensors = logging.tqdm(filename_to_tensors, desc="Saving checkpoint shards")
      for shard_file, tensors in filename_to_tensors:
          shard = {}
          for tensor in tensors:
              if _is_dtensor_available and isinstance(state_dict[tensor], DTensor):
                  full_tensor = state_dict[tensor].full_tensor()
                  # to get the correctly ordered tensor we need to repack if packed
                  if _get_parameter_tp_plan(tensor, self._tp_plan) == "local_packed_rowwise":
                      full_tensor = repack_weights(full_tensor, -1, self._tp_size, 2)
                  shard[tensor] = full_tensor.contiguous()  # only do contiguous after it's permuted correctly
              else:
                  shard[tensor] = state_dict[tensor].contiguous()
              # delete reference, see https://github.com/huggingface/transformers/pull/34890
              del state_dict[tensor]

          # remake shard with onloaded parameters if necessary
          # (the accelerate version requirement was already checked above)
          if module_map:
              # init state_dict for this shard
              shard_state_dict = dict.fromkeys(shard, "")
              for module_name in shard:
                  # note that get_state_dict_from_offload can update with meta tensors
                  # if both a parent module and its descendant are offloaded
                  tensor = shard_state_dict[module_name]
                  if tensor == "" or (isinstance(tensor, torch.Tensor) and tensor.device.type == "meta"):
                      # update state dict with onloaded parameters
                      module = module_map[module_name]
                      shard_state_dict = get_state_dict_from_offload(module, module_name, shard_state_dict)

              # assign shard to be the completed state dict
              shard = shard_state_dict
              del shard_state_dict
              gc.collect()

          if safe_serialization:
              # At some point we will need to deal better with save_function (used for TPU and other distributed
              # joyfulness), but for now this enough.
              safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
          else:
              save_function(shard, os.path.join(save_directory, shard_file))

      del state_dict

      if index is None:
          path_to_weights = os.path.join(save_directory, weights_name)
          logger.info(f"Model weights saved in {path_to_weights}")
      else:
          save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
          save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
          # Save the index as well
          with open(save_index_file, "w", encoding="utf-8") as f:
              content = json.dumps(index, indent=2, sort_keys=True) + "\n"
              f.write(content)
          logger.info(
              f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
              f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
              f"index located at {save_index_file}."
          )

      if push_to_hub:
          # Eventually create an empty model card
          model_card = create_and_tag_model_card(
              repo_id, self.model_tags, token=token, ignore_metadata_errors=ignore_metadata_errors
          )

          # Update model card if needed:
          model_card.save(os.path.join(save_directory, "README.md"))

          self._upload_modified_files(
              save_directory,
              repo_id,
              files_timestamps,
              commit_message=commit_message,
              token=token,
              create_pr=create_pr,
          )
  @wraps(PushToHubMixin.push_to_hub)
  def push_to_hub(self, *args, **kwargs):
      # Merge the model's own `model_tags` with any `tags` passed by the caller, then defer to
      # `PushToHubMixin.push_to_hub`. (`@wraps` copies the parent docstring, so this override is documented here.)
      #
      # Work on a copy: appending to `self.model_tags` directly would permanently add the caller's tags to the
      # model as a side effect of a single push.
      tags = list(self.model_tags) if self.model_tags is not None else []

      extra_tags = kwargs.get("tags")
      if extra_tags is None:
          # An explicit `tags=None` means "no extra tags"; iterating None would raise TypeError.
          extra_tags = []
      elif isinstance(extra_tags, str):
          extra_tags = [extra_tags]

      for tag in extra_tags:
          if tag not in tags:
              tags.append(tag)

      if tags:
          kwargs["tags"] = tags
      return super().push_to_hub(*args, **kwargs)
  def get_memory_footprint(self, return_buffers=True):
      r"""
      Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
      Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
      PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2

      The footprint is the sum of `nelement() * element_size()` over all parameters, plus all buffers when
      `return_buffers` is set.

      Arguments:
          return_buffers (`bool`, *optional*, defaults to `True`):
              Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
              are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
              norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2

      Returns:
          `int`: the number of bytes occupied by the counted tensors.
      """
      counted = list(self.parameters())
      if return_buffers:
          counted.extend(self.buffers())
      return sum(t.nelement() * t.element_size() for t in counted)
  @wraps(torch.nn.Module.cuda)
  def cuda(self, *args, **kwargs):
      # Quantization-aware override of `nn.Module.cuda` (the docstring is copied from torch by `@wraps`).
      quant_method = getattr(self, "quantization_method", None)

      if quant_method == QuantizationMethod.HQQ:
          from hqq.core.quantize import HQQLinear

          # HQQLinear keeps some tensors in its 'meta' attribute, which the generic `nn.Module.cuda` does not
          # reach, so every HQQLinear layer is moved explicitly afterwards.
          super().cuda(*args, **kwargs)
          target_device = args[0] if len(args) > 0 else kwargs.get("device", "cuda")
          for module in self.modules():
              if isinstance(module, HQQLinear):
                  module.cuda(target_device)
          return self

      # bitsandbytes models: 8-bit models can never be moved, 4-bit ones only with bitsandbytes >= 0.43.2.
      if quant_method == QuantizationMethod.BITS_AND_BYTES:
          if getattr(self, "is_loaded_in_8bit", False):
              raise ValueError(
                  "Calling `cuda()` is not supported for `8-bit` quantized models. "
                  " Please use the model as it is, since the model has already been set to the correct devices."
              )
          if version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
              raise ValueError(
                  "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                  f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
              )

      return super().cuda(*args, **kwargs)
  @wraps(torch.nn.Module.to)
  def to(self, *args, **kwargs):
      # Quantization-aware override of `nn.Module.to` (the docstring is copied from torch by `@wraps`).
      # For BNB/GPTQ models, we prevent users from casting the model to another dtype to restrict unwanted behaviours.
      # the correct API should be to load the model with the desired dtype directly through `from_pretrained`.
      dtype_present_in_args = "dtype" in kwargs
      if not dtype_present_in_args:
          for arg in args:
              if isinstance(arg, torch.dtype):
                  dtype_present_in_args = True
                  break

      if getattr(self, "quantization_method", None) == QuantizationMethod.HQQ:
          from hqq.core.quantize import HQQLinear

          # Since HQQLinear stores some tensors in the 'meta' attribute, we must
          # explicitly move the parameters to the target device for each HQQLinear layer after `to`.
          super().to(*args, **kwargs)

          # Resolve the requested device / dtype from every call form `nn.Module.to` accepts: `to(device)`,
          # `to(dtype)`, `to(device, dtype)`, `to(tensor)` and keywords. Keywords take precedence.
          # Taking `args[0]` as the device unconditionally handed a dtype to `HQQLinear.cuda` for
          # `model.to(torch.float16)` and raised IndexError for `model.to(dtype=torch.float16)`.
          target_device = kwargs.get("device")
          target_dtype = kwargs.get("dtype")
          for arg in args:
              if isinstance(arg, bool):
                  # positional `non_blocking` flag; not relevant for relocating HQQ layers
                  continue
              if isinstance(arg, torch.dtype):
                  if target_dtype is None:
                      target_dtype = arg
              elif isinstance(arg, torch.Tensor):
                  if target_device is None:
                      target_device = arg.device
                  if target_dtype is None:
                      target_dtype = arg.dtype
              elif target_device is None:
                  target_device = arg

          for module in self.modules():
              if not isinstance(module, HQQLinear):
                  continue
              device = target_device
              if device is None:
                  # Only a dtype change was requested, so keep the layer where it is. HQQLinear is assumed to
                  # expose its placement as `.device` (NOTE(review): confirm for the installed hqq version);
                  # otherwise fall back to the model's device.
                  device = getattr(module, "device", None)
                  if device is None:
                      device = self.device
              # Due to the current messy implementation of HQQLinear, updating `compute_dtype`
              # followed by calling the `cuda` method achieves the intended behavior of `to`,
              # even when the target device is CPU.
              if target_dtype is not None:
                  module.compute_dtype = target_dtype
              module.cuda(device)
          return self

      if dtype_present_in_args and getattr(self, "quantization_method", None) == QuantizationMethod.QUARK:
          raise ValueError("Casting a Quark quantized model to a new `dtype` is not supported.")

      # Checks if the model has been loaded in 4-bit or 8-bit with BNB
      if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
          if dtype_present_in_args:
              raise ValueError(
                  "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
                  " desired `dtype` by passing the correct `dtype` argument."
              )

          if getattr(self, "is_loaded_in_8bit", False):
              raise ValueError(
                  "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
                  " model has already been set to the correct devices and casted to the correct `dtype`."
              )
          elif version.parse(importlib.metadata.version("bitsandbytes")) < version.parse("0.43.2"):
              raise ValueError(
                  "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
                  f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
              )
      elif getattr(self, "quantization_method", None) == QuantizationMethod.GPTQ:
          if dtype_present_in_args:
              raise ValueError(
                  "You cannot cast a GPTQ model in a new `dtype`. Make sure to load the model using `from_pretrained` using the desired"
                  " `dtype` by passing the correct `dtype` argument."
              )
      return super().to(*args, **kwargs)
  def half(self, *args):
      """
      Cast floating point parameters and buffers to `torch.float16`, like `nn.Module.half`.

      Raises:
          `ValueError`: If the model is quantized (`is_quantized` is truthy); such models must be used with the dtype
              they were loaded with.
      """
      if not getattr(self, "is_quantized", False):
          return super().half(*args)
      raise ValueError(
          "`.half()` is not supported for quantized model. Please use the model as it is, since the"
          " model has already been casted to the correct `dtype`."
      )
  def float(self, *args):
      """
      Cast floating point parameters and buffers to `torch.float32`, like `nn.Module.float`.

      Raises:
          `ValueError`: If the model is quantized (`is_quantized` is truthy); such models must be used with the dtype
              they were loaded with.
      """
      if not getattr(self, "is_quantized", False):
          return super().float(*args)
      raise ValueError(
          "`.float()` is not supported for quantized model. Please use the model as it is, since the"
          " model has already been casted to the correct `dtype`."
      )
  @classmethod
  def get_init_context(cls, is_quantized: bool, _is_ds_init_called: bool):
      """
      Build the list of context managers under which the model should be instantiated.

      Without DeepSpeed ZeRO-3 the model is created with `no_init_weights()` and `init_empty_weights()`. With ZeRO-3,
      `no_init_weights()` is always used. Quantized models additionally get `init_empty_weights()` and
      `set_quantized_state()`. Non-quantized models get `deepspeed.zero.Init` and `set_zero3_state()`, but only if
      `_is_ds_init_called` is false.

      Args:
          is_quantized (`bool`): Whether the model is being loaded with a quantizer.
          _is_ds_init_called (`bool`): Whether `deepspeed.zero.Init` has already been entered by the caller.

      Returns:
          `list`: the context managers to enter, in order.
      """
      if not is_deepspeed_zero3_enabled():
          return [no_init_weights(), init_empty_weights()]

      import deepspeed

      contexts = [no_init_weights()]
      if is_quantized:
          contexts.extend([init_empty_weights(), set_quantized_state()])
      elif not _is_ds_init_called:
          # We cannot initialize the model on meta device with deepspeed when not quantized
          logger.info("Detected DeepSpeed ZeRO-3: activating zero.init() for this model")
          contexts.extend([deepspeed.zero.Init(config_dict_or_path=deepspeed_config()), set_zero3_state()])
      return contexts
  3813. @classmethod
  3814. @restore_default_dtype
  3815. def from_pretrained(
  3816. cls: type[SpecificPreTrainedModelType],
  3817. pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
  3818. *model_args,
  3819. config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
  3820. cache_dir: Optional[Union[str, os.PathLike]] = None,
  3821. ignore_mismatched_sizes: bool = False,
  3822. force_download: bool = False,
  3823. local_files_only: bool = False,
  3824. token: Optional[Union[str, bool]] = None,
  3825. revision: str = "main",
  3826. use_safetensors: Optional[bool] = None,
  3827. weights_only: bool = True,
  3828. **kwargs,
  3829. ) -> SpecificPreTrainedModelType:
  3830. r"""
  3831. Instantiate a pretrained pytorch model from a pre-trained model configuration.
  3832. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
  3833. the model, you should first set it back in training mode with `model.train()`.
  3834. The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
  3835. pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
  3836. task.
  3837. The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
  3838. weights are discarded.
  3839. Parameters:
  3840. pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
  3841. Can be either:
  3842. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  3843. - A path to a *directory* containing model weights saved using
  3844. [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  3845. - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
  3846. this case, `from_tf` should be set to `True` and a configuration object should be provided as
  3847. `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
  3848. PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
  3849. - A path or url to a model folder containing a *flax checkpoint file* in *.msgpack* format (e.g,
  3850. `./flax_model/` containing `flax_model.msgpack`). In this case, `from_flax` should be set to
  3851. `True`.
  3852. - `None` if you are both providing the configuration and state dictionary (resp. with keyword
  3853. arguments `config` and `state_dict`).
  3854. model_args (sequence of positional arguments, *optional*):
  3855. All remaining positional arguments will be passed to the underlying model's `__init__` method.
  3856. config (`Union[PretrainedConfig, str, os.PathLike]`, *optional*):
  3857. Can be either:
  3858. - an instance of a class derived from [`PretrainedConfig`],
  3859. - a string or path valid as input to [`~PretrainedConfig.from_pretrained`].
  3860. Configuration for the model to use instead of an automatically loaded configuration. Configuration can
  3861. be automatically loaded when:
  3862. - The model is a model provided by the library (loaded with the *model id* string of a pretrained
  3863. model).
  3864. - The model was saved using [`~PreTrainedModel.save_pretrained`] and is reloaded by supplying the
  3865. save directory.
  3866. - The model is loaded by supplying a local directory as `pretrained_model_name_or_path` and a
  3867. configuration JSON file named *config.json* is found in the directory.
  3868. state_dict (`dict[str, torch.Tensor]`, *optional*):
  3869. A state dictionary to use instead of a state dictionary loaded from saved weights file.
  3870. This option can be used if you want to create a model from a pretrained configuration but load your own
  3871. weights. In this case though, you should check if using [`~PreTrainedModel.save_pretrained`] and
  3872. [`~PreTrainedModel.from_pretrained`] is not a simpler option.
  3873. cache_dir (`Union[str, os.PathLike]`, *optional*):
  3874. Path to a directory in which a downloaded pretrained model configuration should be cached if the
  3875. standard cache should not be used.
  3876. from_tf (`bool`, *optional*, defaults to `False`):
  3877. Load the model weights from a TensorFlow checkpoint save file (see docstring of
  3878. `pretrained_model_name_or_path` argument).
  3879. from_flax (`bool`, *optional*, defaults to `False`):
  3880. Load the model weights from a Flax checkpoint save file (see docstring of
  3881. `pretrained_model_name_or_path` argument).
  3882. ignore_mismatched_sizes (`bool`, *optional*, defaults to `False`):
  3883. Whether or not to raise an error if some of the weights from the checkpoint do not have the same size
  3884. as the weights of the model (if for instance, you are instantiating a model with 10 labels from a
  3885. checkpoint with 3 labels).
  3886. force_download (`bool`, *optional*, defaults to `False`):
  3887. Whether or not to force the (re-)download of the model weights and configuration files, overriding the
  3888. cached versions if they exist.
  3889. proxies (`dict[str, str]`, *optional*):
  3890. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  3891. 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
  3892. output_loading_info(`bool`, *optional*, defaults to `False`):
  3893. Whether ot not to also return a dictionary containing missing keys, unexpected keys and error messages.
  3894. local_files_only(`bool`, *optional*, defaults to `False`):
  3895. Whether or not to only look at local files (i.e., do not try to download the model).
  3896. token (`str` or `bool`, *optional*):
  3897. The token to use as HTTP bearer authorization for remote files. If `True`, or not specified, will use
  3898. the token generated when running `hf auth login` (stored in `~/.huggingface`).
  3899. revision (`str`, *optional*, defaults to `"main"`):
  3900. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  3901. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  3902. identifier allowed by git.
  3903. <Tip>
  3904. To test a pull request you made on the Hub, you can pass `revision="refs/pr/<pr_number>"`.
  3905. </Tip>
  3906. attn_implementation (`str`, *optional*):
  3907. The attention implementation to use in the model (if relevant). Can be any of `"eager"` (manual implementation of the attention), `"sdpa"` (using [`F.scaled_dot_product_attention`](https://pytorch.org/docs/master/generated/torch.nn.functional.scaled_dot_product_attention.html)), `"flash_attention_2"` (using [Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention)), or `"flash_attention_3"` (using [Dao-AILab/flash-attention/hopper](https://github.com/Dao-AILab/flash-attention/tree/main/hopper)). By default, if available, SDPA will be used for torch>=2.1.1. The default is otherwise the manual `"eager"` implementation.
  3908. Accept HF kernel references in the form:
  3909. <namespace>/<repo_name>[@<revision>][:<kernel_name>]
  3910. - <namespace> and <repo_name> are any non-"/" and non-":" sequences.
  3911. - "@<revision>" is optional (branch, tag, or commit-ish), e.g. "@main", "@v1.2.0", "@abc123".
  3912. - ":<kernel_name>" is optional and selects a function inside the kernel repo.
  3913. - Both options can appear together and in this order only: @revision first, then :kernel_name.
  3914. - We intentionally allow a leading "<wrapper>|" prefix (e.g., "flash|...") because the code
  3915. strips it before loading; '|' is not excluded in the character classes here.
  3916. Examples that match:
  3917. "org/model"
  3918. "org/model@main"
  3919. "org/model:custom_kernel"
  3920. "org/model@v1.2.3:custom_kernel"
  3921. > Parameters for big model inference
  3922. dtype (`str` or `torch.dtype`, *optional*):
  3923. Override the default `torch_dtype` and load the model under a specific `dtype`. The different options
  3924. are:
  3925. 1. `torch.float16` or `torch.bfloat16` or `torch.float`: load in a specified
  3926. `dtype`, ignoring the model's `config.dtype` if one exists. If not specified
  3927. - the model will get loaded in `torch.float` (fp32).
  3928. 2. `"auto"` - A `dtype` or `torch_dtype` entry in the `config.json` file of the model will be
  3929. attempted to be used. If this entry isn't found then next check the `dtype` of the first weight in
  3930. the checkpoint that's of a floating point type and use that as `dtype`. This will load the model
  3931. using the `dtype` it was saved in at the end of the training. It can't be used as an indicator of how
  3932. the model was trained. Since it could be trained in one of half precision dtypes, but saved in fp32.
  3933. 3. A string that is a valid `torch.dtype`. E.g. "float32" loads the model in `torch.float32`, "float16" loads in `torch.float16` etc.
  3934. <Tip>
  3935. For some models the `dtype` they were trained in is unknown - you may try to check the model's paper or
  3936. reach out to the authors and ask them to add this information to the model's card and to insert the
  3937. `dtype` or `torch_dtype` entry in `config.json` on the hub.
  3938. </Tip>
  3939. device_map (`str` or `dict[str, Union[int, str, torch.device]]` or `int` or `torch.device`, *optional*):
  3940. A map that specifies where each submodule should go. It doesn't need to be refined to each
  3941. parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
  3942. same device. If we only pass the device (*e.g.*, `"cpu"`, `"cuda:1"`, `"mps"`, or a GPU ordinal rank
  3943. like `1`) on which the model will be allocated, the device map will map the entire model to this
  3944. device. Passing `device_map = 0` means put the whole model on GPU 0.
  3945. To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
  3946. more information about each option see [designing a device
  3947. map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
  3948. max_memory (`Dict`, *optional*):
  3949. A dictionary device identifier to maximum memory if using `device_map`. Will default to the maximum memory available for each
  3950. GPU and the available CPU RAM if unset.
  3951. tp_plan (`str`, *optional*):
  3952. A torch tensor parallel plan, see [here](https://pytorch.org/tutorials/intermediate/TP_tutorial.html). Currently, it only accepts
  3953. `tp_plan="auto"` to use predefined plan based on the model. Note that if you use it, you should launch your script accordingly with
  3954. `torchrun [args] script.py`. This will be much faster than using a `device_map`, but has limitations.
  3955. tp_size (`str`, *optional*):
  3956. A torch tensor parallel degree. If not provided would default to world size.
  3957. device_mesh (`torch.distributed.DeviceMesh`, *optional*):
  3958. A torch device mesh. If not provided would default to world size. Used only for tensor parallel for now.
  3959. If provided, it has to contain dimension named `"tp"` in case it's > 1 dimensional, this dimension will be used for tensor parallelism
  3960. offload_folder (`str` or `os.PathLike`, *optional*):
  3961. If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
  3962. offload_buffers (`bool`, *optional*):
  3963. Whether or not to offload the buffers with the model parameters.
  3964. quantization_config (`Union[QuantizationConfigMixin,Dict]`, *optional*):
  3965. A dictionary of configuration parameters or a QuantizationConfigMixin object for quantization (e.g
  3966. bitsandbytes, gptq). There may be other quantization-related kwargs, including `load_in_4bit` and
  3967. `load_in_8bit`, which are parsed by QuantizationConfigParser. Supported only for bitsandbytes
  3968. quantizations and not preferred. consider inserting all such arguments into quantization_config
  3969. instead.
  3970. subfolder (`str`, *optional*, defaults to `""`):
  3971. In case the relevant files are located inside a subfolder of the model repo on huggingface.co, you can
  3972. specify the folder name here.
  3973. variant (`str`, *optional*):
  3974. If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
  3975. ignored when using `from_tf` or `from_flax`.
  3976. use_safetensors (`bool`, *optional*, defaults to `None`):
  3977. Whether or not to use `safetensors` checkpoints. Defaults to `None`. If not specified and `safetensors`
  3978. is not installed, it will be set to `False`.
  3979. weights_only (`bool`, *optional*, defaults to `True`):
  3980. Indicates whether unpickler should be restricted to loading only tensors, primitive types,
  3981. dictionaries and any types added via torch.serialization.add_safe_globals().
  3982. When set to False, we can load wrapper tensor subclass weights.
  3983. key_mapping (`dict[str, str], *optional*):
  3984. A potential mapping of the weight names if using a model on the Hub which is compatible to a Transformers
  3985. architecture, but was not converted accordingly.
  3986. kwargs (remaining dictionary of keyword arguments, *optional*):
  3987. Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
  3988. `output_attentions=True`). Behaves differently depending on whether a `config` is provided or
  3989. automatically loaded:
  3990. - If a configuration is provided with `config`, `**kwargs` will be directly passed to the
  3991. underlying model's `__init__` method (we assume all relevant updates to the configuration have
  3992. already been done)
  3993. - If a configuration is not provided, `kwargs` will be first passed to the configuration class
  3994. initialization function ([`~PretrainedConfig.from_pretrained`]). Each key of `kwargs` that
  3995. corresponds to a configuration attribute will be used to override said attribute with the
  3996. supplied `kwargs` value. Remaining keys that do not correspond to any configuration attribute
  3997. will be passed to the underlying model's `__init__` function.
  3998. <Tip>
  3999. Activate the special ["offline-mode"](https://huggingface.co/transformers/installation.html#offline-mode) to
  4000. use this method in a firewalled environment.
  4001. </Tip>
  4002. Examples:
  4003. ```python
  4004. >>> from transformers import BertConfig, BertModel
  4005. >>> # Download model and configuration from huggingface.co and cache.
  4006. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased")
  4007. >>> # Model was saved using *save_pretrained('./test/saved_model/')* (for example purposes, not runnable).
  4008. >>> model = BertModel.from_pretrained("./test/saved_model/")
  4009. >>> # Update configuration during loading.
  4010. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", output_attentions=True)
  4011. >>> assert model.config.output_attentions == True
  4012. >>> # Loading from a TF checkpoint file instead of a PyTorch model (slower, for example purposes, not runnable).
  4013. >>> config = BertConfig.from_json_file("./tf_model/my_tf_model_config.json")
  4014. >>> model = BertModel.from_pretrained("./tf_model/my_tf_checkpoint.ckpt.index", from_tf=True, config=config)
  4015. >>> # Loading from a Flax checkpoint file instead of a PyTorch model (slower)
  4016. >>> model = BertModel.from_pretrained("google-bert/bert-base-uncased", from_flax=True)
  4017. ```
  4018. """
  4019. state_dict = kwargs.pop("state_dict", None)
  4020. from_tf = kwargs.pop("from_tf", False)
  4021. from_flax = kwargs.pop("from_flax", False)
  4022. proxies = kwargs.pop("proxies", None)
  4023. output_loading_info = kwargs.pop("output_loading_info", False)
  4024. use_auth_token = kwargs.pop("use_auth_token", None)
  4025. from_pipeline = kwargs.pop("_from_pipeline", None)
  4026. from_auto_class = kwargs.pop("_from_auto", False)
  4027. dtype = kwargs.pop("dtype", None)
  4028. torch_dtype = kwargs.pop("torch_dtype", None) # kept for BC
  4029. device_map = kwargs.pop("device_map", None)
  4030. max_memory = kwargs.pop("max_memory", None)
  4031. offload_folder = kwargs.pop("offload_folder", None)
  4032. offload_buffers = kwargs.pop("offload_buffers", False)
  4033. load_in_8bit = kwargs.pop("load_in_8bit", False)
  4034. load_in_4bit = kwargs.pop("load_in_4bit", False)
  4035. quantization_config = kwargs.pop("quantization_config", None)
  4036. subfolder = kwargs.pop("subfolder", "")
  4037. commit_hash = kwargs.pop("_commit_hash", None)
  4038. variant = kwargs.pop("variant", None)
  4039. adapter_kwargs = kwargs.pop("adapter_kwargs", {})
  4040. adapter_name = kwargs.pop("adapter_name", "default")
  4041. generation_config = kwargs.pop("generation_config", None)
  4042. gguf_file = kwargs.pop("gguf_file", None)
  4043. tp_plan = kwargs.pop("tp_plan", None)
  4044. tp_size = kwargs.pop("tp_size", None)
  4045. distributed_config: DistributedConfig = kwargs.pop("distributed_config", None)
  4046. device_mesh = kwargs.pop("device_mesh", None)
  4047. trust_remote_code = kwargs.pop("trust_remote_code", None)
  4048. use_kernels = kwargs.pop("use_kernels", False)
  4049. key_mapping = kwargs.pop("key_mapping", None)
  4050. # Load models with hardcoded key mapping on class for VLMs only, to keep BC and standardize model
  4051. if key_mapping is None and any(
  4052. allowed_name in class_name.__name__.lower() for class_name in cls.__mro__[:-1] for allowed_name in VLMS
  4053. ):
  4054. key_mapping = cls._checkpoint_conversion_mapping
  4055. if distributed_config is not None:
  4056. tp_plan = "auto"
  4057. # Not used anymore -- remove them from the kwargs
  4058. _ = kwargs.pop("resume_download", None)
  4059. _ = kwargs.pop("mirror", None)
  4060. _ = kwargs.pop("_fast_init", True)
  4061. _ = kwargs.pop("low_cpu_mem_usage", None)
  4062. _ = kwargs.pop("offload_state_dict", None)
  4063. # For BC on torch_dtype argument
  4064. if torch_dtype is not None:
  4065. logger.warning_once("`torch_dtype` is deprecated! Use `dtype` instead!")
  4066. # If both kwargs are provided, use `dtype`
  4067. dtype = dtype if dtype is not None else torch_dtype
  4068. if state_dict is not None and (pretrained_model_name_or_path is not None or gguf_file is not None):
  4069. raise ValueError(
  4070. "`state_dict` cannot be passed together with a model name or a `gguf_file`. Use one of the two loading strategies."
  4071. )
  4072. if tp_size is not None and tp_plan is None:
  4073. raise ValueError("tp_plan has to be set when tp_size is passed.")
  4074. if tp_plan is not None and tp_plan != "auto":
  4075. # TODO: we can relax this check when we support taking tp_plan from a json file, for example.
  4076. raise ValueError(f"tp_plan supports 'auto' only for now but got {tp_plan}.")
  4077. if tp_plan is not None and device_map is not None:
  4078. raise ValueError(
  4079. "`tp_plan` and `device_map` are mutually exclusive. Choose either one for parallelization."
  4080. )
  4081. if device_map == "auto" and int(os.environ.get("WORLD_SIZE", "0")):
  4082. logger.info(
  4083. "You've set device_map=`auto` while triggering a distributed run with torchrun. This might lead to unexpected behavior. "
  4084. "If your plan is to load the model on each device, you should set device_map={"
  4085. ": PartialState().process_index} where PartialState comes from accelerate library"
  4086. )
  4087. # We need to correctly dispatch the model on the current process device. The easiest way for this is to use a simple
  4088. # `device_map` pointing to the correct device
  4089. if tp_plan is not None:
  4090. if device_mesh is None:
  4091. tp_plan, device_map, device_mesh, tp_size = initialize_tensor_parallelism(tp_plan, tp_size=tp_size)
  4092. else:
  4093. if device_mesh.ndim > 1:
  4094. if "tp" not in device_mesh.mesh_dim_names:
  4095. raise ValueError(
  4096. "When using `tp_plan` and n-d `device_mesh`, it must contain a 'tp' dimension. "
  4097. "Please provide a valid `device_mesh`."
  4098. )
  4099. device_mesh = device_mesh["tp"]
  4100. tp_size = device_mesh.size()
  4101. device_map = torch.device(f"{device_mesh.device_type}:{int(os.environ['LOCAL_RANK'])}")
  4102. if tp_size is None:
  4103. tp_size = torch.distributed.get_world_size()
  4104. if use_auth_token is not None:
  4105. warnings.warn(
  4106. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  4107. FutureWarning,
  4108. )
  4109. if token is not None:
  4110. raise ValueError(
  4111. "`token` and `use_auth_token` are both specified. Please set only the argument `token`."
  4112. )
  4113. token = use_auth_token
  4114. if token is not None and adapter_kwargs is not None and "token" not in adapter_kwargs:
  4115. adapter_kwargs["token"] = token
  4116. if gguf_file is not None and not is_accelerate_available():
  4117. raise ValueError("accelerate is required when loading a GGUF file `pip install accelerate`.")
  4118. if commit_hash is None:
  4119. if not isinstance(config, PretrainedConfig):
  4120. # We make a call to the config file first (which may be absent) to get the commit hash as soon as possible
  4121. resolved_config_file = cached_file(
  4122. pretrained_model_name_or_path,
  4123. CONFIG_NAME,
  4124. cache_dir=cache_dir,
  4125. force_download=force_download,
  4126. proxies=proxies,
  4127. local_files_only=local_files_only,
  4128. token=token,
  4129. revision=revision,
  4130. subfolder=subfolder,
  4131. _raise_exceptions_for_gated_repo=False,
  4132. _raise_exceptions_for_missing_entries=False,
  4133. _raise_exceptions_for_connection_errors=False,
  4134. )
  4135. commit_hash = extract_commit_hash(resolved_config_file, commit_hash)
  4136. else:
  4137. commit_hash = getattr(config, "_commit_hash", None)
  4138. if is_peft_available():
  4139. _adapter_model_path = adapter_kwargs.pop("_adapter_model_path", None)
  4140. if _adapter_model_path is None:
  4141. _adapter_model_path = find_adapter_config_file(
  4142. pretrained_model_name_or_path,
  4143. cache_dir=cache_dir,
  4144. force_download=force_download,
  4145. proxies=proxies,
  4146. local_files_only=local_files_only,
  4147. _commit_hash=commit_hash,
  4148. **adapter_kwargs,
  4149. )
  4150. if _adapter_model_path is not None and os.path.isfile(_adapter_model_path):
  4151. with open(_adapter_model_path, "r", encoding="utf-8") as f:
  4152. _adapter_model_path = pretrained_model_name_or_path
  4153. pretrained_model_name_or_path = json.load(f)["base_model_name_or_path"]
  4154. else:
  4155. _adapter_model_path = None
  4156. # Potentially detect context manager or global device, and use it (only if no device_map was provided)
  4157. if device_map is None and not is_deepspeed_zero3_enabled():
  4158. device_in_context = get_torch_context_manager_or_global_device()
  4159. if device_in_context == torch.device("meta"):
  4160. raise RuntimeError(
  4161. "You are using `from_pretrained` with a meta device context manager or `torch.set_default_device('meta')`.\n"
  4162. "This is an anti-pattern as `from_pretrained` wants to load existing weights.\nIf you want to initialize an "
  4163. "empty model on the meta device, use the context manager or global device with `from_config`, or `ModelClass(config)`"
  4164. )
  4165. device_map = device_in_context
  4166. # change device_map into a map if we passed an int, a str or a torch.device
  4167. if isinstance(device_map, torch.device):
  4168. device_map = {"": device_map}
  4169. elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
  4170. try:
  4171. device_map = {"": torch.device(device_map)}
  4172. except RuntimeError:
  4173. raise ValueError(
  4174. "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
  4175. f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
  4176. )
  4177. elif isinstance(device_map, int):
  4178. if device_map < 0:
  4179. raise ValueError(
  4180. "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
  4181. )
  4182. else:
  4183. device_map = {"": device_map}
  4184. if device_map is not None:
  4185. if is_deepspeed_zero3_enabled():
  4186. raise ValueError("DeepSpeed Zero-3 is not compatible with passing a `device_map`.")
  4187. if not is_accelerate_available():
  4188. raise ValueError(
  4189. "Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
  4190. "requires `accelerate`. You can install it with `pip install accelerate`"
  4191. )
  4192. # handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
  4193. if load_in_4bit or load_in_8bit:
  4194. if quantization_config is not None:
  4195. raise ValueError(
  4196. "You can't pass `load_in_4bit`or `load_in_8bit` as a kwarg when passing "
  4197. "`quantization_config` argument at the same time."
  4198. )
  4199. # preparing BitsAndBytesConfig from kwargs
  4200. config_dict = {k: v for k, v in kwargs.items() if k in inspect.signature(BitsAndBytesConfig).parameters}
  4201. config_dict = {**config_dict, "load_in_4bit": load_in_4bit, "load_in_8bit": load_in_8bit}
  4202. quantization_config, kwargs = BitsAndBytesConfig.from_dict(
  4203. config_dict=config_dict, return_unused_kwargs=True, **kwargs
  4204. )
  4205. logger.warning(
  4206. "The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. "
  4207. "Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead."
  4208. )
  4209. from_pt = not (from_tf | from_flax)
  4210. user_agent = {"file_type": "model", "framework": "pytorch", "from_auto_class": from_auto_class}
  4211. if from_pipeline is not None:
  4212. user_agent["using_pipeline"] = from_pipeline
  4213. if is_offline_mode() and not local_files_only:
  4214. logger.info("Offline mode: forcing local_files_only=True")
  4215. local_files_only = True
  4216. # Load config if we don't provide a configuration
  4217. if not isinstance(config, PretrainedConfig):
  4218. config_path = config if config is not None else pretrained_model_name_or_path
  4219. config, model_kwargs = cls.config_class.from_pretrained(
  4220. config_path,
  4221. cache_dir=cache_dir,
  4222. return_unused_kwargs=True,
  4223. force_download=force_download,
  4224. proxies=proxies,
  4225. local_files_only=local_files_only,
  4226. token=token,
  4227. revision=revision,
  4228. subfolder=subfolder,
  4229. gguf_file=gguf_file,
  4230. _from_auto=from_auto_class,
  4231. _from_pipeline=from_pipeline,
  4232. **kwargs,
  4233. )
  4234. if "gguf_file" in model_kwargs:
  4235. model_kwargs.pop("gguf_file")
  4236. else:
  4237. config = copy.deepcopy(config)
  4238. model_kwargs = kwargs
  4239. # Because some composite configs call super().__init__ before instantiating the sub-configs, we need this call
  4240. # to correctly redispatch recursively if the kwarg is provided
  4241. if "attn_implementation" in kwargs:
  4242. config._attn_implementation = kwargs.pop("attn_implementation")
  4243. transformers_explicit_filename = getattr(config, "transformers_weights", None)
  4244. if transformers_explicit_filename is not None:
  4245. if not transformers_explicit_filename.endswith(
  4246. ".safetensors"
  4247. ) and not transformers_explicit_filename.endswith(".safetensors.index.json"):
  4248. raise ValueError(
  4249. "The transformers file in the config seems to be incorrect: it is neither a safetensors file "
  4250. "(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
  4251. f"{transformers_explicit_filename}"
  4252. )
  4253. hf_quantizer, config, dtype, device_map = get_hf_quantizer(
  4254. config, quantization_config, dtype, from_tf, from_flax, device_map, weights_only, user_agent
  4255. )
  4256. if gguf_file is not None and hf_quantizer is not None:
  4257. raise ValueError(
  4258. "You cannot combine Quantization and loading a model from a GGUF file, try again by making sure you did not passed a `quantization_config` or that you did not load a quantized model from the Hub."
  4259. )
  4260. if (
  4261. gguf_file
  4262. and device_map is not None
  4263. and ((isinstance(device_map, dict) and "disk" in device_map.values()) or "disk" in device_map)
  4264. ):
  4265. raise RuntimeError(
  4266. "One or more modules is configured to be mapped to disk. Disk offload is not supported for models "
  4267. "loaded from GGUF files."
  4268. )
  4269. checkpoint_files, sharded_metadata = _get_resolved_checkpoint_files(
  4270. pretrained_model_name_or_path=pretrained_model_name_or_path,
  4271. subfolder=subfolder,
  4272. variant=variant,
  4273. gguf_file=gguf_file,
  4274. from_tf=from_tf,
  4275. from_flax=from_flax,
  4276. use_safetensors=use_safetensors,
  4277. cache_dir=cache_dir,
  4278. force_download=force_download,
  4279. proxies=proxies,
  4280. local_files_only=local_files_only,
  4281. token=token,
  4282. user_agent=user_agent,
  4283. revision=revision,
  4284. commit_hash=commit_hash,
  4285. is_remote_code=cls._auto_class is not None,
  4286. transformers_explicit_filename=transformers_explicit_filename,
  4287. )
  4288. is_sharded = sharded_metadata is not None
  4289. is_quantized = hf_quantizer is not None
  4290. is_from_file = pretrained_model_name_or_path is not None or gguf_file is not None
  4291. if is_from_file and not is_sharded and checkpoint_files[0].endswith(".safetensors"):
  4292. with safe_open(checkpoint_files[0], framework="pt") as f:
  4293. metadata = f.metadata()
  4294. if metadata is None:
  4295. # Assume it's a pytorch checkpoint (introduced for timm checkpoints)
  4296. pass
  4297. elif metadata.get("format") == "pt":
  4298. pass
  4299. elif metadata.get("format") == "tf":
  4300. from_tf = True
  4301. logger.info("A TensorFlow safetensors file is being loaded in a PyTorch model.")
  4302. elif metadata.get("format") == "flax":
  4303. from_flax = True
  4304. logger.info("A Flax safetensors file is being loaded in a PyTorch model.")
  4305. elif metadata.get("format") == "mlx":
  4306. # This is a mlx file, we assume weights are compatible with pt
  4307. pass
  4308. else:
  4309. raise ValueError(
  4310. f"Incompatible safetensors file. File metadata is not ['pt', 'tf', 'flax', 'mlx'] but {metadata.get('format')}"
  4311. )
  4312. from_pt = not (from_tf | from_flax)
  4313. if from_pt:
  4314. if gguf_file:
  4315. from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
  4316. # we need a dummy model to get the state_dict - for this reason, we keep the state_dict as if it was
  4317. # passed directly as a kwarg from now on
  4318. with torch.device("meta"):
  4319. dummy_model = cls(config)
  4320. state_dict = load_gguf_checkpoint(checkpoint_files[0], return_tensors=True, model_to_load=dummy_model)[
  4321. "tensors"
  4322. ]
  4323. # Find the correct dtype based on current state
  4324. config, dtype, dtype_orig = _get_dtype(
  4325. cls, dtype, checkpoint_files, config, sharded_metadata, state_dict, weights_only
  4326. )
  4327. config.name_or_path = pretrained_model_name_or_path
  4328. model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
  4329. config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
  4330. with ContextManagers(model_init_context):
  4331. # Let's make sure we don't run the init function of buffer modules
  4332. model = cls(config, *model_args, **model_kwargs)
  4333. # Make sure to tie the weights correctly
  4334. model.tie_weights()
  4335. # make sure we use the model's config since the __init__ call might have copied it
  4336. config = model.config
  4337. # Find fp32 modules if needed
  4338. keep_in_fp32_modules = []
  4339. # The _keep_in_fp32_modules flag is only used to avoid bf16 -> fp16 casting precision issues. It was introduced
  4340. # in case of force loading a model that should stay bf16 in fp16 (which includes a few quantizers as this is a pre-processing
  4341. # step for e.g. bitsandbytes). See https://github.com/huggingface/transformers/issues/20287 for details.
  4342. if model._keep_in_fp32_modules is not None and (
  4343. dtype == torch.float16 or getattr(hf_quantizer, "use_keep_in_fp32_modules", False)
  4344. ):
  4345. keep_in_fp32_modules.extend(model._keep_in_fp32_modules)
  4346. if model._keep_in_fp32_modules_strict is not None and (dtype == torch.float16 or dtype == torch.bfloat16):
  4347. keep_in_fp32_modules.extend(model._keep_in_fp32_modules_strict)
  4348. keep_in_fp32_regex = None
  4349. if keep_in_fp32_modules:
  4350. # We need to match exact layers, so we add either `.` on each side, or start/end of string
  4351. keep_in_fp32_regex = re.compile("|".join([rf"((^|\.){module}($|\.))" for module in keep_in_fp32_modules]))
  4352. if hf_quantizer is not None:
  4353. hf_quantizer.preprocess_model(
  4354. model=model,
  4355. device_map=device_map,
  4356. keep_in_fp32_modules=model._keep_in_fp32_modules,
  4357. config=config,
  4358. use_kernels=use_kernels,
  4359. )
  4360. # We store the original dtype for quantized models as we cannot easily retrieve it
  4361. # once the weights have been quantized
  4362. # Note that once you have loaded a quantized model, you can't change its dtype so this will
  4363. # remain a single source of truth
  4364. original_dtype = dtype if dtype is not None else torch.get_default_dtype()
  4365. def _assign_original_dtype(module):
  4366. for child in module.children():
  4367. if isinstance(child, PreTrainedModel):
  4368. child.config._pre_quantization_dtype = original_dtype
  4369. _assign_original_dtype(child)
  4370. config._pre_quantization_dtype = original_dtype
  4371. _assign_original_dtype(model)
  4372. # Torchao needs access to all metadata later
  4373. if hf_quantizer.quantization_config.quant_method == QuantizationMethod.TORCHAO:
  4374. hf_quantizer.set_metadata(checkpoint_files)
  4375. if _torch_distributed_available and device_mesh is not None:
  4376. model = distribute_model(model, distributed_config, device_mesh, tp_size)
  4377. # Prepare the full device map
  4378. if device_map is not None:
  4379. device_map = _get_device_map(model, device_map, max_memory, hf_quantizer, dtype, keep_in_fp32_regex)
  4380. # Finalize model weight initialization
  4381. if from_tf:
  4382. model, loading_info = cls._load_from_tf(model, config, checkpoint_files)
  4383. elif from_flax:
  4384. model = cls._load_from_flax(model, checkpoint_files)
  4385. elif from_pt:
  4386. # restore default dtype
  4387. if dtype_orig is not None:
  4388. torch.set_default_dtype(dtype_orig)
  4389. (
  4390. model,
  4391. missing_keys,
  4392. unexpected_keys,
  4393. mismatched_keys,
  4394. offload_index,
  4395. error_msgs,
  4396. ) = cls._load_pretrained_model(
  4397. model,
  4398. state_dict,
  4399. checkpoint_files,
  4400. pretrained_model_name_or_path,
  4401. ignore_mismatched_sizes=ignore_mismatched_sizes,
  4402. sharded_metadata=sharded_metadata,
  4403. device_map=device_map,
  4404. disk_offload_folder=offload_folder,
  4405. dtype=dtype,
  4406. hf_quantizer=hf_quantizer,
  4407. keep_in_fp32_regex=keep_in_fp32_regex,
  4408. device_mesh=device_mesh,
  4409. key_mapping=key_mapping,
  4410. weights_only=weights_only,
  4411. )
  4412. # make sure token embedding weights are still tied if needed
  4413. model.tie_weights()
  4414. # Set model in evaluation mode to deactivate DropOut modules by default
  4415. model.eval()
  4416. # check if using kernels
  4417. if use_kernels:
  4418. model.use_kernels = True
  4419. # If it is a model with generation capabilities, attempt to load generation files (generation config,
  4420. # custom generate function)
  4421. if model.can_generate() and generation_config is not None:
  4422. logger.info("The user-defined `generation_config` will be used to override the default generation config.")
  4423. model.generation_config = model.generation_config.from_dict(generation_config.to_dict())
  4424. elif model.can_generate() and pretrained_model_name_or_path is not None:
  4425. repo_loading_kwargs = {
  4426. "cache_dir": cache_dir,
  4427. "force_download": force_download,
  4428. "proxies": proxies,
  4429. "local_files_only": local_files_only,
  4430. "token": token,
  4431. "revision": revision,
  4432. "subfolder": subfolder,
  4433. **kwargs,
  4434. }
  4435. # Load generation config
  4436. try:
  4437. model.generation_config = GenerationConfig.from_pretrained(
  4438. pretrained_model_name_or_path,
  4439. _from_auto=from_auto_class,
  4440. _from_pipeline=from_pipeline,
  4441. **repo_loading_kwargs,
  4442. )
  4443. except OSError:
  4444. logger.info(
  4445. "Generation config file not found, using a generation config created from the model config."
  4446. )
  4447. pass
  4448. # Load custom generate function if `pretrained_model_name_or_path` defines it (and override `generate`)
  4449. if hasattr(model, "load_custom_generate"):
  4450. try:
  4451. custom_generate = model.load_custom_generate(
  4452. pretrained_model_name_or_path, trust_remote_code=trust_remote_code, **repo_loading_kwargs
  4453. )
  4454. model.generate = functools.partial(custom_generate, model=model)
  4455. except OSError: # there is no custom generate function
  4456. pass
  4457. # Dispatch model with hooks on all devices if necessary (not needed with a tp_plan, so we skip it as it slightly
  4458. # harm performances)
  4459. if device_map is not None and device_mesh is None:
  4460. device_map_kwargs = {
  4461. "device_map": device_map,
  4462. "offload_dir": offload_folder,
  4463. "offload_index": offload_index,
  4464. "offload_buffers": offload_buffers,
  4465. }
  4466. if "skip_keys" in inspect.signature(dispatch_model).parameters:
  4467. device_map_kwargs["skip_keys"] = model._skip_keys_device_placement
  4468. # For HQQ method we force-set the hooks for single GPU envs
  4469. if (
  4470. "force_hooks" in inspect.signature(dispatch_model).parameters
  4471. and hf_quantizer is not None
  4472. and hf_quantizer.quantization_config.quant_method == QuantizationMethod.HQQ
  4473. ):
  4474. device_map_kwargs["force_hooks"] = True
  4475. if (
  4476. hf_quantizer is not None
  4477. and hf_quantizer.quantization_config.quant_method == QuantizationMethod.FBGEMM_FP8
  4478. and isinstance(device_map, dict)
  4479. and ("cpu" in device_map.values() or "disk" in device_map.values())
  4480. ):
  4481. device_map_kwargs["offload_buffers"] = True
  4482. if not is_fsdp_enabled() and not is_deepspeed_zero3_enabled():
  4483. dispatch_model(model, **device_map_kwargs)
  4484. if hf_quantizer is not None:
  4485. model.hf_quantizer = hf_quantizer
  4486. hf_quantizer.postprocess_model(model, config=config)
  4487. if _adapter_model_path is not None:
  4488. adapter_kwargs["key_mapping"] = key_mapping
  4489. model.load_adapter(
  4490. _adapter_model_path,
  4491. adapter_name=adapter_name,
  4492. token=token,
  4493. adapter_kwargs=adapter_kwargs,
  4494. )
  4495. if output_loading_info:
  4496. if from_pt:
  4497. loading_info = {
  4498. "missing_keys": missing_keys,
  4499. "unexpected_keys": unexpected_keys,
  4500. "mismatched_keys": mismatched_keys,
  4501. "error_msgs": error_msgs,
  4502. }
  4503. elif from_flax:
  4504. loading_info = None
  4505. return model, loading_info
  4506. return model
  4507. @staticmethod
  4508. def _fix_state_dict_key_on_load(key: str) -> tuple[str, bool]:
  4509. """Replace legacy parameter names with their modern equivalents. E.g. beta -> bias, gamma -> weight."""
  4510. # Rename LayerNorm beta & gamma params for some early models ported from Tensorflow (e.g. Bert)
  4511. # This rename is logged.
  4512. if key.endswith("LayerNorm.beta"):
  4513. return key.replace("LayerNorm.beta", "LayerNorm.bias"), True
  4514. if key.endswith("LayerNorm.gamma"):
  4515. return key.replace("LayerNorm.gamma", "LayerNorm.weight"), True
  4516. # Rename weight norm parametrizations to match changes across torch versions.
  4517. # Impacts a number of speech/wav2vec models. e.g. Hubert, Wav2Vec2, and others.
  4518. # This rename is not logged.
  4519. if hasattr(nn.utils.parametrizations, "weight_norm"):
  4520. if key.endswith("weight_g"):
  4521. return key.replace("weight_g", "parametrizations.weight.original0"), True
  4522. if key.endswith("weight_v"):
  4523. return key.replace("weight_v", "parametrizations.weight.original1"), True
  4524. else:
  4525. if key.endswith("parametrizations.weight.original0"):
  4526. return key.replace("parametrizations.weight.original0", "weight_g"), True
  4527. if key.endswith("parametrizations.weight.original1"):
  4528. return key.replace("parametrizations.weight.original1", "weight_v"), True
  4529. return key, False
  4530. def _get_key_renaming_mapping(
  4531. self,
  4532. checkpoint_keys: list[str],
  4533. key_mapping: Optional[dict[str, str]] = None,
  4534. loading_base_model_from_task_state_dict: bool = False,
  4535. loading_task_model_from_base_state_dict: bool = False,
  4536. ):
  4537. """
  4538. Compute a mapping between the serialized keys on disk `checkpoint_keys`, and the keys that the model
  4539. that we are loading expects. This is the single entry point for key renaming that will be used during
  4540. loading.
  4541. Log if any parameters have been renamed.
  4542. """
  4543. prefix = self.base_model_prefix
  4544. _prefix = f"{prefix}."
  4545. if loading_task_model_from_base_state_dict:
  4546. task_specific_expected_keys, base_model_keys = [], []
  4547. for key in self.state_dict():
  4548. if key.startswith(_prefix):
  4549. base_model_keys.append(key[len(_prefix) :])
  4550. else:
  4551. task_specific_expected_keys.append(key)
  4552. renamed_keys = {}
  4553. key_renaming_mapping = {}
  4554. for key in checkpoint_keys:
  4555. # Class specific rename
  4556. new_key, has_changed = self._fix_state_dict_key_on_load(key)
  4557. # Optionally map the key according to `key_mapping`
  4558. if key_mapping is not None:
  4559. for pattern, replacement in key_mapping.items():
  4560. new_key, n_replace = re.subn(pattern, replacement, new_key)
  4561. # Early exit of the loop
  4562. if n_replace > 0:
  4563. has_changed = True
  4564. break
  4565. # In this case, we need to add the prefix to the keys, to match them to the expected keys
  4566. if loading_task_model_from_base_state_dict:
  4567. # small sanity check: if we find a key that is only part of the task-specific keys, we raise
  4568. # (if it's also part of the base model, we do not raise and assume it comes from there)
  4569. if new_key in task_specific_expected_keys and new_key not in base_model_keys:
  4570. raise ValueError(
  4571. "The state dictionary of the model you are trying to load is corrupted. Are you sure it was "
  4572. "properly saved?"
  4573. )
  4574. new_key = ".".join([prefix, new_key])
  4575. # In this case we need to remove the prefix from the key to match them to the expected keys, and use
  4576. # only the keys starting with the prefix
  4577. elif loading_base_model_from_task_state_dict:
  4578. if not new_key.startswith(_prefix):
  4579. continue
  4580. new_key = new_key[len(_prefix) :]
  4581. key_renaming_mapping[key] = new_key
  4582. # track gamma/beta rename for logging
  4583. if has_changed:
  4584. if key.endswith("LayerNorm.gamma"):
  4585. renamed_keys["LayerNorm.gamma"] = (key, new_key)
  4586. elif key.endswith("LayerNorm.beta"):
  4587. renamed_keys["LayerNorm.beta"] = (key, new_key)
  4588. if renamed_keys:
  4589. warning_msg = f"A pretrained model of type `{self.__class__.__name__}` "
  4590. warning_msg += "contains parameters that have been renamed internally (a few are listed below but more are present in the model):\n"
  4591. for old_key, new_key in renamed_keys.values():
  4592. warning_msg += f"* `{old_key}` -> `{new_key}`\n"
  4593. warning_msg += "If you are using a model from the Hub, consider submitting a PR to adjust these weights and help future users."
  4594. logger.info_once(warning_msg)
  4595. return key_renaming_mapping
  4596. @staticmethod
  4597. def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
  4598. """
  4599. Similar to `_fix_state_dict_key_on_load` allows to define hook for state dict key renaming on model save.
  4600. Do nothing by default, but can be overridden in particular models.
  4601. """
  4602. return key, False
  4603. def _fix_state_dict_keys_on_save(self, state_dict):
  4604. """
  4605. Similar to `_fix_state_dict_keys_on_load` allows to define hook for state dict key renaming on model save.
  4606. Apply `_fix_state_dict_key_on_save` to all keys in `state_dict`.
  4607. """
  4608. return {self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items()}
  @classmethod
  def _load_pretrained_model(
      cls,
      model: "PreTrainedModel",
      state_dict: Optional[dict],
      checkpoint_files: Optional[list[str]],
      pretrained_model_name_or_path: Optional[str],
      ignore_mismatched_sizes: bool = False,
      sharded_metadata: Optional[dict] = None,
      device_map: Optional[dict] = None,
      disk_offload_folder: Optional[str] = None,
      dtype: Optional[torch.dtype] = None,
      hf_quantizer: Optional[HfQuantizer] = None,
      keep_in_fp32_regex: Optional[re.Pattern] = None,
      device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
      key_mapping: Optional[dict[str, str]] = None,
      weights_only: bool = True,
  ):
      """
      Load checkpoint weights into the already-instantiated `model`.

      The steps are:
      1. Rename the checkpoint keys to the names `model` expects (via `model._get_key_renaming_mapping`).
      2. Move missing and shape-mismatched parameters off the meta device and initialize them.
      3. Prepare disk offloading if the `device_map` requests it.
      4. Load every entry of `checkpoint_files`, serially or through a thread pool. The thread pool is used when
         `HF_ENABLE_PARALLEL_LOADING` is truthy and DeepSpeed ZeRO-3 is not enabled.
      5. Report missing / unexpected / mismatched keys.

      Args:
          model (`PreTrainedModel`): The model instance to fill.
          state_dict (`dict`, *optional*): An already-loaded state dict. When given without `sharded_metadata`, its
              keys are used as the checkpoint keys, and it is forwarded to every shard-loading call.
          checkpoint_files (`list[str]`, *optional*): Checkpoint file paths. The first one is only read for its keys
              when neither `sharded_metadata` nor `state_dict` is provided.
          pretrained_model_name_or_path (`str`, *optional*): Only used in the log messages.
          ignore_mismatched_sizes (`bool`): Forwarded to `_find_mismatched_keys`.
          sharded_metadata (`dict`, *optional*): Index of a sharded checkpoint. Its `"all_checkpoint_keys"` and
              `"weight_map"` entries are read.
          device_map (`dict`, *optional*): Weight placement. Any `"disk"` value enables disk offloading.
          disk_offload_folder (`str`, *optional*): Folder for disk-offloaded weights (created if missing).
          dtype (`torch.dtype`, *optional*): Passed to `_move_missing_keys_from_meta_to_cpu`. It is also recorded as
              a string in the safetensors disk offload index (`"float32"` when `None`).
          hf_quantizer (`HfQuantizer`, *optional*): Quantizer in use, if any.
          keep_in_fp32_regex (`re.Pattern`, *optional*): Parameters whose name matches are cast to `torch.float32`.
          device_mesh (`DeviceMesh`, *optional*): Mesh used for tensor parallelism.
          key_mapping (`dict[str, str]`, *optional*): Regex renames forwarded to `_get_key_renaming_mapping`.
          weights_only (`bool`): Forwarded to the state-dict loading helpers.

      Returns:
          `tuple`: `(model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs)`.

      Raises:
          ValueError: If the `device_map` offloads weights to `"disk"`, no `disk_offload_folder` is given, and the
              checkpoint is not a `.safetensors` file.
          RuntimeError: If loading the shards reported any error message.
      """
      # TODO: we should only be calling hf_quantizer.skip_placement or something like that
      is_quantized = hf_quantizer is not None
      # HQQ / Quark quantized models skip the CUDA allocator warmup below
      is_hqq_or_quark = is_quantized and hf_quantizer.quantization_config.quant_method in {
          QuantizationMethod.HQQ,
          QuantizationMethod.QUARK,
      }

      # Get all the keys of the state dicts that we have to initialize the model
      if sharded_metadata is not None:
          original_checkpoint_keys = sharded_metadata["all_checkpoint_keys"]
      elif state_dict is not None:
          original_checkpoint_keys = list(state_dict.keys())
      else:
          # Loaded on the meta device: only the key names are needed here
          original_checkpoint_keys = list(
              load_state_dict(checkpoint_files[0], map_location="meta", weights_only=weights_only).keys()
          )

      # Check if we are in a special state, i.e. loading from a state dict coming from a different architecture
      prefix = model.base_model_prefix
      # NOTE(review): `startswith(prefix)` has no trailing ".", so a key like `<prefix>_head.weight` also counts
      # as prefixed — confirm this is intended.
      has_prefix_module = any(s.startswith(prefix) for s in original_checkpoint_keys) if len(prefix) > 0 else False
      expects_prefix_module = hasattr(model, prefix) if len(prefix) > 0 else False
      loading_task_model_from_base_state_dict = not has_prefix_module and expects_prefix_module
      loading_base_model_from_task_state_dict = has_prefix_module and not expects_prefix_module

      # Find the key names that the model expects from the serialized keys
      key_renaming_mapping = model._get_key_renaming_mapping(
          original_checkpoint_keys,
          key_mapping,
          loading_base_model_from_task_state_dict,
          loading_task_model_from_base_state_dict,
      )
      checkpoint_keys = list(key_renaming_mapping.values())

      # Find missing and unexpected keys from the state dict
      missing_keys, unexpected_keys = _find_missing_and_unexpected_keys(
          model, original_checkpoint_keys, checkpoint_keys, loading_base_model_from_task_state_dict, hf_quantizer
      )
      # Find all the keys with shape mismatch (if we ignore the mismatch, the weights need to be newly initialized the
      # same way as missing keys)
      mismatched_keys, mismatched_shapes = _find_mismatched_keys(
          model,
          state_dict,
          checkpoint_files,
          ignore_mismatched_sizes,
          key_renaming_mapping,
          is_quantized,
          weights_only,
      )

      # We need to update both the mapping and the list of checkpoint keys to remove the mismatched and unexpected ones
      key_renaming_mapping = {
          k: v for k, v in key_renaming_mapping.items() if v not in mismatched_keys and v not in unexpected_keys
      }
      checkpoint_keys = list(key_renaming_mapping.values())

      # Move missing (and potentially mismatched) keys back to cpu from meta device (because they won't be moved when
      # loading the weights as they are not in the loaded state dict)
      model._move_missing_keys_from_meta_to_cpu(missing_keys + mismatched_keys, dtype, hf_quantizer)

      # correctly initialize the missing (and potentially mismatched) keys
      model._initialize_missing_keys(missing_keys + mismatched_keys, is_quantized)

      # Set some modules to fp32 if needed
      if keep_in_fp32_regex is not None:
          for name, param in model.named_parameters():
              if keep_in_fp32_regex.search(name):
                  # param = param.to(torch.float32) does not work here as only in the local scope.
                  param.data = param.data.to(torch.float32)

      # Get reverse key mapping (model key -> original checkpoint key)
      reverse_key_renaming_mapping = {v: k for k, v in key_renaming_mapping.items()}

      is_offloaded_safetensors = False
      # This offload index is for params explicitly on the "disk" in the device_map
      disk_offload_index = None
      disk_only_shard_files = []
      # Prepare parameters offloading if needed
      if device_map is not None and "disk" in device_map.values():
          if disk_offload_folder is not None:
              os.makedirs(disk_offload_folder, exist_ok=True)
          # With a safetensors checkpoint, the offload index below points straight at the checkpoint files, so no
          # index is saved to `disk_offload_folder` after loading
          is_offloaded_safetensors = checkpoint_files is not None and checkpoint_files[0].endswith(".safetensors")
          if disk_offload_folder is None and not is_offloaded_safetensors:
              raise ValueError(
                  "The current `device_map` had weights offloaded to the disk. Please provide an `offload_folder`"
                  " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                  " offers the weights in this format."
              )
          if is_offloaded_safetensors:
              param_device_map = expand_device_map(device_map, checkpoint_keys)
              str_dtype = str(dtype).replace("torch.", "") if dtype is not None else "float32"
              if sharded_metadata is None:
                  # Single-file checkpoint: every weight lives in checkpoint_files[0]
                  weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
              else:
                  folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
                  # Fix the weight map keys according to the key mapping
                  weight_map = {
                      key_renaming_mapping[k]: v
                      for k, v in sharded_metadata["weight_map"].items()
                      if k in key_renaming_mapping
                  }
                  weight_map = {k: os.path.join(folder, v) for k, v in weight_map.items()}
                  # Find potential checkpoints containing only offloaded weights
                  disk_only_shard_files = get_disk_only_shard_files(device_map, weight_map)
              # Index keyed by model parameter name; "weight_name" is the name inside the safetensors file
              disk_offload_index = {
                  name: {
                      "safetensors_file": file,
                      "weight_name": reverse_key_renaming_mapping[name],
                      "dtype": str_dtype,
                  }
                  for name, file in weight_map.items()
                  if param_device_map[name] == "disk"
              }
          else:
              disk_offload_index = {}
      # To be able to iterate, even if we don't use it if the state_dict is already provided
      elif state_dict is not None:
          checkpoint_files = [""]

      # Compute expected model keys
      expected_keys = list(model.state_dict().keys())
      if hf_quantizer is not None:
          expected_keys = hf_quantizer.update_expected_keys(model, expected_keys, checkpoint_keys)

      # The TP plan verification only runs when the logger's own level is WARNING or higher
      if logger.level >= logging.WARNING:
          verify_tp_plan(expected_keys, getattr(model, "_tp_plan", None))

      # Warmup cuda to load the weights much faster on devices
      if device_map is not None and not is_hqq_or_quark:
          expanded_device_map = expand_device_map(device_map, expected_keys)
          caching_allocator_warmup(model, expanded_device_map, hf_quantizer)

      # Prepare and compatibilize arguments for serial and parallel shard loading: one tuple per shard file.
      # NOTE(review): every tuple carries the same `disk_offload_index` object built above; presumably the loaders
      # update it and return it — verify against `load_shard_file`.
      args_list = [
          (
              shard_file,
              state_dict,
              disk_only_shard_files,
              is_quantized,
              device_map,
              hf_quantizer,
              key_renaming_mapping,
              weights_only,
              model,
              reverse_key_renaming_mapping,
              disk_offload_folder,
              disk_offload_index,
              keep_in_fp32_regex,
              device_mesh,
          )
          for shard_file in checkpoint_files
      ]

      error_msgs = []

      if (
          os.environ.get("HF_ENABLE_PARALLEL_LOADING", "").upper() in ENV_VARS_TRUE_VALUES
          and not is_deepspeed_zero3_enabled()
      ):
          _error_msgs, disk_offload_index = load_shard_files_with_threadpool(args_list)
          error_msgs += _error_msgs
      else:
          # Only show a progress bar when there is more than one shard
          if len(args_list) > 1:
              args_list = logging.tqdm(args_list, desc="Loading checkpoint shards")

          for args in args_list:
              _error_msgs, disk_offload_index = load_shard_file(args)
              error_msgs += _error_msgs

      # Save offloaded index if needed
      if disk_offload_index is not None and len(disk_offload_index) > 0 and not is_offloaded_safetensors:
          save_offload_index(disk_offload_index, disk_offload_folder)
          disk_offload_index = None

      # Post-processing for tensor parallelism
      if device_mesh is not None:
          # When using TP, the device map is a single device for all parameters.
          # NOTE(review): this assumes `device_map` is not None whenever `device_mesh` is set.
          tp_device = list(device_map.values())[0]
          # This is needed for the RotaryEmbedding, which was not initialized on the correct device as it is
          # not part of the state_dict (persistent=False)
          for buffer in model.buffers():
              if buffer.device != tp_device:
                  buffer.data = buffer.to(tp_device)

          # In this case, the top-most task module weights were not moved to device and parallelized as they
          # were not part of the loaded weights: do it now
          if loading_task_model_from_base_state_dict:
              parameters_to_initialize = {
                  name: param for name, param in model.named_parameters() if not name.startswith(prefix)
              }
              for name, param in parameters_to_initialize.items():
                  # If it is still on meta here, it means that it's a tied weight that will be tied later anyway -> skip it
                  if param.device.type == "meta":
                      continue
                  # Shard the param
                  to_contiguous, casting_dtype = _infer_parameter_dtype(model, name, param, keep_in_fp32_regex)
                  shard_and_distribute_module(
                      model,
                      param.to(tp_device),
                      param,
                      name,
                      casting_dtype,
                      to_contiguous,
                      device_mesh.get_local_rank(),
                      device_mesh,
                  )

      # Remove potential model-specific exceptions from the warnings
      missing_keys, unexpected_keys = model._adjust_missing_and_unexpected_keys(
          missing_keys, unexpected_keys, loading_task_model_from_base_state_dict
      )

      # All potential warnings/infos
      if len(error_msgs) > 0:
          error_msg = "\n\t".join(error_msgs)
          if "size mismatch" in error_msg:
              error_msg += (
                  "\n\tYou may consider adding `ignore_mismatched_sizes=True` in the model `from_pretrained` method."
              )
          raise RuntimeError(f"Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}")
      if len(unexpected_keys) > 0:
          # Warn loudly only when the checkpoint was saved from this very architecture; otherwise this is expected
          archs = [] if model.config.architectures is None else model.config.architectures
          warner = logger.warning if model.__class__.__name__ in archs else logger.info
          warner(
              f"Some weights of the model checkpoint at {pretrained_model_name_or_path} were not used when"
              f" initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are"
              f" initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or"
              " with another architecture (e.g. initializing a BertForSequenceClassification model from a"
              " BertForPreTraining model).\n- This IS NOT expected if you are initializing"
              f" {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical"
              " (initializing a BertForSequenceClassification model from a BertForSequenceClassification model)."
          )
      if len(missing_keys) > 0:
          logger.warning(
              f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
              f" {pretrained_model_name_or_path} and are newly initialized: {missing_keys}\nYou should probably"
              " TRAIN this model on a down-stream task to be able to use it for predictions and inference."
          )
      if len(mismatched_keys) > 0:
          mismatched_warning = "\n".join(
              [
                  f"- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated"
                  for key, (shape1, shape2) in zip(mismatched_keys, mismatched_shapes)
              ]
          )
          logger.warning(
              f"Some weights of {model.__class__.__name__} were not initialized from the model checkpoint at"
              f" {pretrained_model_name_or_path} and are newly initialized because the shapes did not"
              f" match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able"
              " to use it for predictions and inference."
          )

      return model, missing_keys, unexpected_keys, mismatched_keys, disk_offload_index, error_msgs
  4856. @classmethod
  4857. def _load_from_tf(cls, model, config, checkpoint_files):
  4858. if checkpoint_files[0].endswith(".index"):
  4859. # Load from a TensorFlow 1.X checkpoint - provided by original authors
  4860. model = cls.load_tf_weights(model, config, checkpoint_files[0][:-6]) # Remove the '.index'
  4861. loading_info = None
  4862. else:
  4863. # Load from our TensorFlow 2.0 checkpoints
  4864. try:
  4865. from .modeling_tf_pytorch_utils import load_tf2_checkpoint_in_pytorch_model
  4866. model, loading_info = load_tf2_checkpoint_in_pytorch_model(
  4867. model, checkpoint_files[0], allow_missing_keys=True, output_loading_info=True
  4868. )
  4869. except ImportError:
  4870. logger.error(
  4871. "Loading a TensorFlow model in PyTorch, requires both PyTorch and TensorFlow to be installed."
  4872. " Please see https://pytorch.org/ and https://www.tensorflow.org/install/ for installation"
  4873. " instructions."
  4874. )
  4875. raise
  4876. return model, loading_info
  @classmethod
  def _load_from_flax(cls, model, checkpoint_files):
      """
      Load the Flax checkpoint `checkpoint_files[0]` into the PyTorch `model` and return the resulting model.

      Raises:
          ImportError: Re-raised, after logging an error, when one occurs while importing or running the Flax
              converter.
      """
      try:
          from .modeling_flax_pytorch_utils import load_flax_checkpoint_in_pytorch_model

          converted_model = load_flax_checkpoint_in_pytorch_model(model, checkpoint_files[0])
      except ImportError:
          logger.error(
              "Loading a Flax model in PyTorch, requires both PyTorch and Flax to be installed. Please see"
              " https://pytorch.org/ and https://flax.readthedocs.io/en/latest/installation.html for"
              " installation instructions."
          )
          raise
      return converted_model
  def retrieve_modules_from_names(self, names, add_prefix=False, remove_prefix=False):
      """
      Collect the submodules of `self` that own at least one of the parameter names in `names`.

      A parameter name `a.b.weight` designates the module `a.b`. A name ending in a digit also designates the module
      two levels up, to cover `torch.nn.ParameterList` entries.

      Args:
          names: Parameter names to look up.
          add_prefix (`bool`): Compare module names with `base_model_prefix` prepended (the root module becomes
              `base_model_prefix` itself).
          remove_prefix (`bool`): Compare module names with a leading `f"{base_model_prefix}."` removed. Takes
              precedence over `add_prefix`.

      Returns:
          `list`: The matching modules, in `self.named_modules()` order.
      """
      owner_names = set()
      for param_name in names:
          parts = param_name.split(".")
          owner_names.add(".".join(parts[:-1]))
          # torch.nn.ParameterList is a special case where two parameter keywords
          # are appended to the module name, *e.g.* bert.special_embeddings.0
          if len(param_name) > 0 and param_name[-1].isdigit():
              owner_names.add(".".join(parts[:-2]))

      def _comparable_name(module_name):
          if remove_prefix:
              return module_name.removeprefix(f"{self.base_model_prefix}.")
          if add_prefix:
              if len(module_name) > 0:
                  return ".".join([self.base_model_prefix, module_name])
              return self.base_model_prefix
          return module_name

      # retrieve all modules that have at least one missing weight name
      return [module for module_name, module in self.named_modules() if _comparable_name(module_name) in owner_names]
  @classmethod
  def register_for_auto_class(cls, auto_class="AutoModel"):
      """
      Register this class with a given auto class. This should only be used for custom models as the ones in the
      library are already mapped with an auto class.

      The auto class's name is stored on `cls._auto_class`.

      Args:
          auto_class (`str` or `type`, *optional*, defaults to `"AutoModel"`):
              The auto class to register this new model with, given by name or as the class itself.

      Raises:
          ValueError: If `transformers.models.auto` has no attribute with that name.
      """
      class_name = auto_class if isinstance(auto_class, str) else auto_class.__name__

      import transformers.models.auto as auto_module

      if hasattr(auto_module, class_name):
          cls._auto_class = class_name
          return
      raise ValueError(f"{class_name} is not a valid auto class.")
  def to_bettertransformer(self) -> "PreTrainedModel":
      """
      Convert the model to PyTorch's native attention fastpath through Optimum's BetterTransformer.

      See [PyTorch's native attention
      implementation](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html) and the [Optimum
      library](https://huggingface.co/docs/optimum/bettertransformer/overview). Only a subset of all Transformers
      models are supported.

      The fastpath speeds up inference through kernel fusions and the use of [nested
      tensors](https://pytorch.org/docs/stable/nested.html). Detailed benchmarks can be found in [this blog
      post](https://medium.com/pytorch/bettertransformer-out-of-the-box-performance-for-huggingface-transformers-3fbe27d50ab2).

      Returns:
          [`PreTrainedModel`]: The model converted to BetterTransformer.

      Raises:
          ImportError: If `optimum` is not installed, or its version is older than 1.7.0.
      """
      if not is_optimum_available():
          raise ImportError("The package `optimum` is required to use Better Transformer.")

      from optimum.version import __version__ as installed_optimum

      if version.parse(installed_optimum) < version.parse("1.7.0"):
          raise ImportError(
              f"Please install optimum>=1.7.0 to use Better Transformer. The version {installed_optimum} was found."
          )

      from optimum.bettertransformer import BetterTransformer as _BetterTransformer

      converted = _BetterTransformer.transform(self)
      return converted
  def reverse_bettertransformer(self):
      """
      Undo [`~PreTrainedModel.to_bettertransformer`], restoring the original modeling code (for example so that the
      model can be saved).

      Returns:
          [`PreTrainedModel`]: The model converted back to the original modeling.

      Raises:
          ImportError: If `optimum` is not installed, or its version is older than 1.7.0.
      """
      if not is_optimum_available():
          raise ImportError("The package `optimum` is required to use Better Transformer.")

      from optimum.version import __version__ as installed_optimum

      if version.parse(installed_optimum) < version.parse("1.7.0"):
          raise ImportError(
              f"Please install optimum>=1.7.0 to use Better Transformer. The version {installed_optimum} was found."
          )

      from optimum.bettertransformer import BetterTransformer as _BetterTransformer

      restored = _BetterTransformer.reverse(self)
      return restored
  def warn_if_padding_and_no_attention_mask(self, input_ids, attention_mask):
      """
      Shows a one-time warning if the input_ids appear to contain padding and no attention mask was given.

      Only the last and first token of each row are inspected, to keep the overhead low. Nothing is checked in these
      cases:
      - while tracing/compiling;
      - when an `attention_mask` is passed;
      - when the config has no `pad_token_id`;
      - when the last dimension of `input_ids` is empty.

      Args:
          input_ids (`torch.Tensor`):
              Input token ids. They are indexed as `input_ids[:, [-1, 0]]`, so at least 2-D (presumably
              `(batch_size, sequence_length)`).
          attention_mask (`torch.Tensor`, *optional*):
              The attention mask passed by the caller, if any.
      """
      # Skip the check during tracing.
      if is_torch_fx_proxy(input_ids) or torch.jit.is_tracing() or is_torchdynamo_compiling():
          return

      if (attention_mask is not None) or (self.config.pad_token_id is None):
          return

      # An empty sequence has no first/last token: `input_ids[:, [-1, 0]]` would raise an IndexError, and this
      # warning helper must not make the forward pass fail.
      if input_ids.shape[-1] == 0:
          return

      # Check only the first and last input IDs to reduce overhead.
      if self.config.pad_token_id in input_ids[:, [-1, 0]]:
          warn_string = (
              "We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See "
              "https://huggingface.co/docs/transformers/troubleshooting"
              "#incorrect-output-when-padding-tokens-arent-masked."
          )

          # If the pad token is equal to either BOS, EOS, or SEP, we do not know whether the user should use an
          # attention_mask or not. In this case, we should still show a warning because this is a rare case.
          if (
              (self.config.bos_token_id is not None and self.config.bos_token_id == self.config.pad_token_id)
              or (self.config.eos_token_id is not None and self.config.eos_token_id == self.config.pad_token_id)
              or (self.config.sep_token_id is not None and self.config.sep_token_id == self.config.pad_token_id)
          ):
              warn_string += (
                  f"\nYou may ignore this warning if your `pad_token_id` ({self.config.pad_token_id}) is identical "
                  f"to the `bos_token_id` ({self.config.bos_token_id}), `eos_token_id` ({self.config.eos_token_id}), "
                  f"or the `sep_token_id` ({self.config.sep_token_id}), and your input is not padded."
              )

          logger.warning_once(warn_string)
  @property
  def supports_tp_plan(self):
      """
      Whether a tensor parallelism plan is available for this model.

      A plan counts if it is defined on the model itself (`_tp_plan`), on its base model (`base_model._tp_plan`),
      or in the config (`config.base_model_tp_plan`); the sources are checked in that order.
      """
      return (
          self._tp_plan is not None
          or getattr(self.base_model, "_tp_plan", None) is not None
          or self.config.base_model_tp_plan is not None
      )
  @property
  def tp_size(self):
      """
      The tensor parallelism degree of the model, as stored in `_tp_size`.

      `None` means the model did not undergo tensor parallel sharding.
      """
      tp_degree = self._tp_size
      return tp_degree
  @property
  def supports_pp_plan(self):
      """
      Whether a pipeline parallelism plan is available, either on the model itself (`_pp_plan`) or on its base
      model (`base_model._pp_plan`).
      """
      if self._pp_plan is None:
          # Fall back to the plan declared on the base model, if any.
          return getattr(self.base_model, "_pp_plan", None) is not None
      return True
  @property
  def loss_function(self):
      """
      Return the loss callable used by this model.

      A loss assigned through the setter takes precedence. Otherwise the loss is looked up in `LOSS_MAPPING` with
      `self.loss_type`, falling back to the `ForCausalLM` entry (with a one-time warning) when `loss_type` is unset
      or not a key of `LOSS_MAPPING`.
      """
      if hasattr(self, "_loss_function"):
          return self._loss_function

      loss_type = getattr(self, "loss_type", None)

      if loss_type is None:
          # Nothing was configured: say so instead of claiming that an unrecognized value was set.
          logger.warning_once("No `loss_type` was set. Using the default loss: `ForCausalLMLoss`.")
          loss_type = "ForCausalLM"
      elif loss_type not in LOSS_MAPPING:
          logger.warning_once(
              f"`loss_type={loss_type}` was set in the config but it is unrecognized. "
              "Using the default loss: `ForCausalLMLoss`."
          )
          loss_type = "ForCausalLM"
      return LOSS_MAPPING[loss_type]

  @loss_function.setter
  def loss_function(self, value):
      """Override the loss callable; the getter returns `value` from then on."""
      self._loss_function = value
  def kernelize(self):
      """
      Apply `kernels.kernelize` to this model for its current device type, then mark kernels as enabled.

      The kernel mode is `Mode.TRAINING` when `self.training` is set and `Mode.INFERENCE` otherwise.

      Raises:
          ValueError: If the `kernels` package is not available.
      """
      if not is_kernels_available():
          raise ValueError(
              "Kernels are not available. To use kernels, please install kernels using `pip install kernels`"
          )
      from kernels import Device, Mode, kernelize

      kernel_mode = Mode.TRAINING if self.training else Mode.INFERENCE
      target_device = Device(type=self.device.type)
      kernelize(self, device=target_device, mode=kernel_mode)
      self._use_kernels = True
  @property
  def use_kernels(self) -> bool:
      """Whether kernels are enabled (`_use_kernels`); `False` when that flag was never set."""
      return getattr(self, "_use_kernels", False)

  @use_kernels.setter
  def use_kernels(self, value: bool) -> None:
      """
      Enable or disable kernels.

      Enabling calls `kernelize()` unless kernels are already enabled. Disabling clears the `_use_kernels` flag and,
      if kernels were enabled, logs a warning: there is no 'unkernelize' routine, so already-applied kernels stay in
      place. Once the flag is cleared, `train()`/`eval()` no longer re-run `kernelize()`.
      """
      already_enabled = getattr(self, "_use_kernels", False)
      if value:
          # Avoid re-kernelizing if already enabled
          if not already_enabled:
              self.kernelize()
          return
      if already_enabled:
          logger.warning_once(
              "Disabling kernels at runtime is a no-op as there is no 'unkernelize' routine; keeping current kernels active."
          )
      self._use_kernels = False
  def get_compiled_call(self, compile_config: Optional[CompileConfig]) -> Callable:
      """Return a `torch.compile`'d version of `self.__call__`.

      This is useful to dynamically choose between non-compiled/compiled `forward` during inference, especially to
      switch between prefill (where we don't want to use the compiled version, to avoid recomputing the graph with
      new shapes) and iterative decoding (where we want the speed-ups of the compiled version with static shapes).

      The compiled callable is cached on the model and only rebuilt when the requested config differs from the one
      used last time. `None` means a default `CompileConfig()`. Models whose `model_type` contains "llama4" always
      get the plain `self.__call__`.
      """
      # TODO try to enable for FULL COMPILE HYBRID CACHE SUPPORT
      if "llama4" in self.config.model_type:
          return self.__call__

      requested_config = compile_config or CompileConfig()
      fallback_config = getattr(self.generation_config, "compile_config", None) or CompileConfig()

      # Only (re)compile when nothing is cached yet or the requested config differs from the previous one
      needs_compile = not hasattr(self, "_compiled_call") or (
          getattr(self, "_last_compile_config", fallback_config) != requested_config
      )
      if needs_compile:
          self._last_compile_config = requested_config
          self._compiled_call = torch.compile(self.__call__, **requested_config.to_dict())
      return self._compiled_call
  @classmethod
  def is_backend_compatible(cls):
      """Return the class-level `_supports_attention_backend` flag of this model class."""
      supported = cls._supports_attention_backend
      return supported
  def _move_missing_keys_from_meta_to_cpu(
      self, missing_keys: list[str], dtype: torch.dtype, hf_quantizer: Optional[HfQuantizer]
  ) -> None:
      """
      Re-create on CPU the missing keys (model parameters that were NOT found in the loaded state dicts) that are
      still on the meta device.

      The new tensors come from `torch.empty_like`, so they are allocated but not initialized.

      Args:
          missing_keys: State-dict keys absent from the checkpoint.
          dtype: dtype of the CPU tensors that are created.
          hf_quantizer: Optional quantizer. Keys for which `param_needs_quantization` is true are created through
              `create_quantized_param` instead of being loaded directly.
      """
      is_quantized = hf_quantizer is not None

      # Under FSDP, non-zero local ranks get *every* parameter re-allocated on CPU, not only the missing ones.
      if is_fsdp_enabled() and not is_local_dist_rank_0() and not is_quantized:
          # Only parameters: buffers are not initialized on the meta device by default
          for name, tensor in self.named_parameters():
              _load_parameter_into_model(self, name, torch.empty_like(tensor, dtype=dtype, device="cpu"))
          return

      meta_device = torch.device("meta")
      state = self.state_dict()
      for key in missing_keys:
          tensor = state[key]
          # Buffers are not initialized on the meta device; skipping non-meta tensors avoids overwriting them
          if tensor.device != meta_device:
              continue
          placeholder = torch.empty_like(tensor, dtype=dtype, device="cpu")
          if is_quantized and hf_quantizer.param_needs_quantization(self, key):
              hf_quantizer.create_quantized_param(self, placeholder, key, "cpu")
          else:
              _load_parameter_into_model(self, key, placeholder)
  def _initialize_missing_keys(self, missing_keys: list[str], is_quantized: bool) -> None:
      """Initialize the missing keys (keys that are part of the model parameters, but were NOT found in the loaded
      state dicts). Since the corresponding weights are missing from the state dict, they will not be replaced and
      need to be initialized correctly (i.e. weight initialization distribution); this is done by calling
      `initialize_weights()`.

      Also sets the `_is_hf_initialized` flag on every parameter/buffer that is not missing, and on every module
      whose children, own parameters and own persistent buffers are all flagged.

      Args:
          missing_keys: State-dict keys that were not found in the checkpoint.
          is_quantized: Whether the model is quantized; when `True`, the DeepSpeed ZeRO-3 gathering path is skipped.
      """
      # `missing_keys` is a list, and membership is tested once per state-dict key: use a set for O(1) lookups.
      missing = set(missing_keys)
      for key in self.state_dict():
          # If it's part of the keys that will be loaded, mark it as already initialized
          if key not in missing:
              param_or_buffer = self.get_parameter_or_buffer(key)
              param_or_buffer._is_hf_initialized = True

      def set_is_initialized_for_modules(module):
          # A module is already initialized if and only if all its children are also already initialized, and all
          # its immediate `nn.Parameter` and persistent buffers are also already initialized.
          # `_non_persistent_buffers_set` holds buffer *names*, so the filter must use the name: testing the tensor
          # itself against that set never matches, which would let non-persistent buffers (never flagged above,
          # since they are not in the state dict) block their module from ever being marked initialized.
          if (
              all(getattr(child, "_is_hf_initialized", False) for child in module.children())
              and all(getattr(param, "_is_hf_initialized", False) for param in module.parameters(recurse=False))
              and all(
                  getattr(buffer, "_is_hf_initialized", False)
                  for name, buffer in module.named_buffers(recurse=False)
                  if name not in module._non_persistent_buffers_set
              )
          ):
              module._is_hf_initialized = True

      # Set the flag on the modules as well. We do it recursively (depth-first), as it's more efficient (we do not
      # need to check the entire state dict of each module, only the immediate children, so we only iterate once over
      # each param)
      self.apply(set_is_initialized_for_modules)

      # This will only initialize submodules that are not marked as initialized by the line above.
      if is_deepspeed_zero3_enabled() and not is_quantized:
          import deepspeed

          not_initialized_parameters = list(
              {v for v in self.state_dict().values() if not getattr(v, "_is_hf_initialized", False)}
          )
          with deepspeed.zero.GatheredParameters(not_initialized_parameters, modifier_rank=0):
              self.initialize_weights()
      else:
          self.initialize_weights()
  5139. def _adjust_missing_and_unexpected_keys(
  5140. self, missing_keys: list[str], unexpected_keys: list[str], loading_task_model_from_base_state_dict: bool
  5141. ) -> tuple[list[str], list[str]]:
  5142. """Adjust the `missing_keys` and `unexpected_keys` based on current model's exception rules, to avoid
  5143. raising unneeded warnings/errors.
  5144. """
  5145. # Old checkpoints may have keys for rotary_emb.inv_freq for each layer, however we moved this buffer to the main model
  5146. # (so the buffer name has changed). Remove them in such a case. This is another exception that was not added to
  5147. # `_keys_to_ignore_on_load_unexpected` as it touches many models -> we add it manually to the existing patterns
  5148. has_inv_freq_buffers = any(buffer.endswith("rotary_emb.inv_freq") for buffer, _ in self.named_buffers())
  5149. additional_unexpected_patterns = [r"rotary_emb\.inv_freq"] if has_inv_freq_buffers else []
  5150. missing_patterns = self._keys_to_ignore_on_load_missing or []
  5151. unexpected_patterns = (self._keys_to_ignore_on_load_unexpected or []) + additional_unexpected_patterns
  5152. ignore_missing_regex, ignore_unexpected_regex = None, None
  5153. if len(missing_patterns) > 0:
  5154. ignore_missing_regex = re.compile("|".join(rf"({pattern})" for pattern in missing_patterns))
  5155. if len(unexpected_patterns) > 0:
  5156. ignore_unexpected_regex = re.compile("|".join(rf"({pattern})" for pattern in unexpected_patterns))
  5157. # Clean-up missing keys
  5158. if ignore_missing_regex is not None:
  5159. missing_keys = [key for key in missing_keys if ignore_missing_regex.search(key) is None]
  5160. # Clean-up unexpected keys
  5161. if ignore_unexpected_regex is not None:
  5162. unexpected_keys = [key for key in unexpected_keys if ignore_unexpected_regex.search(key) is None]
  5163. # Note: only the unexpected keys should remove the added prefix here, to correctly display the original name
  5164. # in the warnings. For missing keys, we should show the prefix in the warning as it's part of the final model
  5165. if loading_task_model_from_base_state_dict:
  5166. _prefix = f"{self.base_model_prefix}."
  5167. unexpected_keys = [k.removeprefix(_prefix) for k in unexpected_keys]
  5168. return missing_keys, unexpected_keys
  def get_parameter_or_buffer(self, target: str):
      """
      Look up `target` as a parameter first, then as a buffer, and finally as an `_extra_state` entry.

      This combines `get_parameter()` and `get_buffer()` in one call. For an `_extra_state` target, the owning
      module's `get_extra_state()` result is returned, but only if that module's class overrides
      `torch.nn.Module.get_extra_state`. Only works when `target` is a leaf of the model.

      Raises:
          AttributeError: If `target` is neither a parameter, a buffer, nor extra state.
      """
      for lookup_name in ("get_parameter", "get_buffer"):
          try:
              return getattr(self, lookup_name)(target)
          except AttributeError:
              continue

      module, attr_name = get_module_from_name(self, target)
      if attr_name == "_extra_state":
          extra_state_getter = getattr(module.__class__, "get_extra_state", torch.nn.Module.get_extra_state)
          if extra_state_getter is not torch.nn.Module.get_extra_state:
              return module.get_extra_state()
      raise AttributeError(f"`{target}` is neither a parameter, buffer, nor extra state.")
  def train(self, mode: bool = True):
      """
      Set the training mode through the parent `train`, then re-run `kernelize()` if kernels are enabled, so that the
      kernel mode (which `kernelize()` derives from `self.training`) follows the new setting.

      Returns:
          The value returned by the parent `train(mode)`.
      """
      result = super().train(mode)
      if self.use_kernels:
          self.kernelize()
      return result
  def eval(self):
      """Switch to evaluation mode; shorthand for `self.train(False)`, whose result is returned."""
      return self.train(False)
  # Give `PreTrainedModel` its own copy of `push_to_hub` (presumably so that specializing the docstring below does not
  # alter the shared implementation), then fill in the docstring's `{object}`-style placeholders for models.
  PreTrainedModel.push_to_hub = copy_func(PreTrainedModel.push_to_hub)
  # `__doc__` is None when docstrings are stripped (e.g. `python -OO`); there is nothing to format then.
  if PreTrainedModel.push_to_hub.__doc__ is not None:
      PreTrainedModel.push_to_hub.__doc__ = PreTrainedModel.push_to_hub.__doc__.format(
          object="model",
          object_class="AutoModel",
          object_files="model file",
      )
  def unwrap_model(model: nn.Module, recursive: bool = False) -> nn.Module:
      """
      Recursively unwraps a model from potential containers (as used in distributed training).

      Args:
          model (`torch.nn.Module`): The model to unwrap.
          recursive (`bool`, *optional*, defaults to `False`):
              Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
              recursively, not just the top-level distributed containers. Requires `accelerate>=0.29.0`; not used when
              `accelerate` is unavailable.

      Raises:
          RuntimeError: If `recursive=True` and the installed `accelerate` is older than 0.29.0.
      """
      if not is_accelerate_available():
          # Without accelerate, peel off `.module` wrappers one level per call (multiple levels may be wrapped).
          return unwrap_model(model.module) if hasattr(model, "module") else model

      # Prefer accelerate's implementation (should always be available when using torch); it also handles things like
      # dynamo-wrapped models.
      extract_kwargs = {}
      if recursive:
          if not is_accelerate_available("0.29.0"):
              raise RuntimeError(
                  "Setting `recursive=True` to `unwrap_model` requires `accelerate` v0.29.0. Please upgrade your version of accelerate"
              )
          extract_kwargs["recursive"] = recursive
      return extract_model_from_parallel(model, **extract_kwargs)
  def expand_device_map(device_map, param_names):
      """
      Map every parameter name to the device of the `device_map` entry that covers it.

      An entry covers a parameter when its key equals the parameter name, is a dotted prefix of it, or is the empty
      string (the whole model). When several entries cover the same parameter, the one appearing last in
      `device_map` wins. Parameters covered by no entry are left out.
      """
      expanded = {}
      for module_name, device in device_map.items():
          dotted_prefix = f"{module_name}."
          for name in param_names:
              if module_name == "" or name == module_name or name.startswith(dotted_prefix):
                  expanded[name] = device
      return expanded
  def is_accelerator_device(device: Union[str, int, torch.device]) -> bool:
      """
      Return whether `device` designates an accelerator, i.e. anything other than "disk", "meta" or "cpu".

      This helper is needed because a device map may contain "disk", which is not a valid `torch.device`.
      """
      if device == "disk":
          return False
      return torch.device(device).type not in ("meta", "cpu")
  def caching_allocator_warmup(model: PreTrainedModel, expanded_device_map: dict, hf_quantizer: Optional[HfQuantizer]):
      """This function warms up the caching allocator based on the size of the model tensors that will reside on each
      device. It allows to have one large call to Malloc, instead of recursively calling it later when loading
      the model, which is actually the loading speed bottleneck.
      Calling this function allows to cut the model loading time by a very large margin.

      A few facts related to loading speed (taking into account the use of this function):
      - When loading a model the first time, it is usually slower than the subsequent times, because the OS is very likely
      to cache the different state dicts (if enough resources/RAM are available)
      - Trying to force the OS to cache the files in advance (by e.g. accessing a small portion of them) is really hard,
      and not a good idea in general as these are low-level OS optimizations that depend on resource usage anyway
      - As of 18/03/2025, loading a Llama 70B model with TP takes ~1 min without file cache, and ~13s with full file cache.
      The baseline, i.e. only loading the tensor shards on device and adjusting dtype (i.e. copying them) is ~5s with full cache.
      These numbers are reported for TP on 4 H100 GPUs.
      - It is useless to pre-allocate more than the model size in this function (i.e. using an `allocation_factor` > 1) as
      cudaMalloc is not a bottleneck at all anymore
      - Loading speed bottleneck is now almost only tensor copy (i.e. changing the dtype) and moving the tensors to the devices.
      However, we cannot really improve on those aspects obviously, as the data needs to be moved/copied in the end.

      Args:
          model: The model being loaded.
          expanded_device_map: Mapping from parameter name to device; "disk", "cpu" and "meta" entries are ignored.
          hf_quantizer: Optional quantizer. It provides the warm-up factor and maps names to the parameter names
              that exist in the model.

      Raises:
          AttributeError: If a parameter listed in `expanded_device_map` cannot be found in `model`.
      """
      factor = 2 if hf_quantizer is None else hf_quantizer.get_accelerator_warm_up_factor()

      # Remove disk, cpu and meta devices, and cast to proper torch.device
      accelerator_device_map = {
          param: torch.device(device) for param, device in expanded_device_map.items() if is_accelerator_device(device)
      }
      if not accelerator_device_map:
          return

      tp_plan = getattr(model, "_tp_plan", []) or []
      # Only build the regex when there actually is a TP plan: joining an empty plan gives `re.compile("")`, which
      # matches every name and would divide *every* parameter's size by the world size as soon as torch.distributed
      # is initialized (e.g. data-parallel runs without tensor parallelism).
      tp_plan_regex = (
          re.compile("|".join([re.escape(plan) for plan in tp_plan]))
          if tp_plan and _torch_distributed_available and torch.distributed.is_initialized()
          else None
      )
      total_byte_count = defaultdict(lambda: 0)
      tied_param_names = _get_tied_weight_keys(model)
      for param_name, device in accelerator_device_map.items():
          # Skip if the parameter has already been accounted for (tied weights)
          if param_name in tied_param_names:
              continue

          # For example in the case of MXFP4 quantization, we need to update the param name to the original param name
          # because the checkpoint contains blocks, and scales, but since we are dequantizing, we need to use the original param name
          if hf_quantizer is not None:
              param_name = hf_quantizer.get_param_name(param_name)

          try:
              param = model.get_parameter_or_buffer(param_name)
          except AttributeError as err:
              raise AttributeError(f"Parameter {param_name} not found in model") from err

          # The dtype of different parameters may be different with composite models or `keep_in_fp32_modules`
          param_byte_count = param.numel() * param.element_size()

          if tp_plan_regex is not None:
              # Replace layer indices by `*` before matching against the (escaped) plan keys
              generic_name = re.sub(r"\.\d+\.", ".*.", param_name)
              param_byte_count //= torch.distributed.get_world_size() if tp_plan_regex.search(generic_name) else 1

          total_byte_count[device] += param_byte_count

      # This will kick off the caching allocator to avoid having to Malloc afterwards
      for device, byte_count in total_byte_count.items():
          if device.type in ["cuda", "xpu"]:
              torch_accelerator_module = getattr(torch, device.type)
              index = device.index if device.index is not None else torch_accelerator_module.current_device()
              device_memory = torch_accelerator_module.mem_get_info(index)[0]
              # Allow up to (max device memory - 1.2 GiB) in resource-constrained hardware configurations. Trying to reserve more
              # than that amount might sometimes lead to unnecessary cuda/xpu OOM, if the last parameter to be loaded on the device is large,
              # and the remaining reserved memory portion is smaller than the param size -> torch will then try to fully re-allocate all
              # the param size, instead of using the remaining reserved part, and allocating only the difference, which can lead
              # to OOM. See https://github.com/huggingface/transformers/issues/37436#issuecomment-2808982161 for more details.
              # Note that we use an absolute value instead of device proportion here, as a 8GiB device could still allocate too much
              # if using e.g. 90% of device size, while a 140GiB device would allocate too little
              byte_count = min(byte_count, max(0, int(device_memory - 1.2 * 1024**3)))
              # If there is *unused* reserved cuda/xpu memory, we can skip/reduce the allocation.
              unused_memory = torch_accelerator_module.memory_reserved(
                  index
              ) - torch_accelerator_module.memory_allocated(index)
              byte_count = max(0, byte_count - unused_memory)
          # Allocate memory: `byte_count // factor` float16 elements of 2 bytes each
          _ = torch.empty(byte_count // factor, dtype=torch.float16, device=device, requires_grad=False)
  def get_disk_only_shard_files(device_map, weight_map):
      """
      Return the shard files whose weights are all offloaded to "disk".

      Each weight name from `weight_map` (weight name -> shard file) is shortened one dotted component at a time
      until it matches a `device_map` key (possibly the empty string). A shard is kept only if every weight it
      contains resolves to "disk".
      """
      devices_per_file = collections.defaultdict(list)
      for name, shard_file in weight_map.items():
          prefix = name
          while prefix and prefix not in device_map:
              # Drop the last dotted component ("a.b.c" -> "a.b", "a" -> "")
              prefix = prefix.rpartition(".")[0]
          devices_per_file[shard_file].append(device_map[prefix])
      return [shard_file for shard_file, devices in devices_per_file.items() if set(devices) == {"disk"}]
  class AttentionInterface(GeneralInterface):
      """
      Dict-like registry of the allowed attention functions, keyed by implementation name.

      New attention functions are added with `register()`. A model that needs to locally overwrite an existing
      function (say `sdpa`) should declare its own instance of this class inside its `modeling_<model>.py` and
      register the override on that instance.
      """

      # Defined on the class (not per instance) so that a call to `register` is reflected in every file, even when a
      # new instance is created in order to locally override a given function.
      _global_mapping = {
          # Both flash-attention entries dispatch to the same forward function.
          "flash_attention_3": flash_attention_forward,
          "flash_attention_2": flash_attention_forward,
          "flex_attention": flex_attention_forward,
          "paged_attention": paged_attention_forward,
          "sdpa": sdpa_attention_forward,
          "sdpa_paged": sdpa_attention_paged_forward,
          "eager_paged": eager_paged_attention_forward,
      }


  # Default registry, shared by every model that does not need to overwrite any of the existing functions
  ALL_ATTENTION_FUNCTIONS: AttentionInterface = AttentionInterface()
  class PreTrainedAudioTokenizerBase(PreTrainedModel):
      """
      Base class defining the extra behavior every `audio_tokenizer` must provide:

      1. Encoding raw audio into discrete audio codebooks (with x channels).
      2. Decoding discrete audio codebooks back into raw audio.

      Some tokenizers can also decode from other representations, but all of them must support 2. (see e.g. `DAC`).
      """

      @abstractmethod
      def encode(self, input_values: torch.Tensor, *args, **kwargs):
          """
          Encode raw audio, as retrieved from the matching `FeatureExtractor`, into discrete audio codebooks
          (with x channels).
          """
          ...

      @abstractmethod
      def decode(self, audio_codes: torch.Tensor, *args, **kwargs):
          """Decode discrete audio codebooks back into raw audio."""
          ...