accelerator.py 193 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985298629872988298929902991299229932994299529962997299829993000300130023003300430053006300730083009301030113012301330143015301630173018301930203021302230233024302530263027302830293030303130323033303430353036303730383039304030413042304330443045304630473048304930503051305230533054305530563057305830593060306130623063306430653066306730683069307030713072307330743075307630773078307930803081308230833084308530863087308830893090309130923093309430953096309730983099310031013102310331043105310631073108310931103111311231133114311531163117311831193120312131223123312431253126312731283129313031313132313331343135313631373138313931403141314231433144314531463147314831493150315131523153315431553156315731583159316031613162316331643165316631673168316931703171317231733174317531763177317831793180318131823183318431853186318731883189319031913192319331943195319631973198319932003201320232033204320532063207320832093210321132123213321432153216321732183219322032213222322332243225322632273228322932303231323232333234323532363237323832393240324132423243324432453246324732483249325032513252325332543255325632573258325932603261326232633264326532663267326832693270327132723273327432753276327
73278327932803281328232833284328532863287328832893290329132923293329432953296329732983299330033013302330333043305330633073308330933103311331233133314331533163317331833193320332133223323332433253326332733283329333033313332333333343335333633373338333933403341334233433344334533463347334833493350335133523353335433553356335733583359336033613362336333643365336633673368336933703371337233733374337533763377337833793380338133823383338433853386338733883389339033913392339333943395339633973398339934003401340234033404340534063407340834093410341134123413341434153416341734183419342034213422342334243425342634273428342934303431343234333434343534363437343834393440344134423443344434453446344734483449345034513452345334543455345634573458345934603461346234633464346534663467346834693470347134723473347434753476347734783479348034813482348334843485348634873488348934903491349234933494349534963497349834993500350135023503350435053506350735083509351035113512351335143515351635173518351935203521352235233524352535263527352835293530353135323533353435353536353735383539354035413542354335443545354635473548354935503551355235533554355535563557355835593560356135623563356435653566356735683569357035713572357335743575357635773578357935803581358235833584358535863587358835893590359135923593359435953596359735983599360036013602360336043605360636073608360936103611361236133614361536163617361836193620362136223623362436253626362736283629363036313632363336343635363636373638363936403641364236433644364536463647364836493650365136523653365436553656365736583659366036613662366336643665366636673668366936703671367236733674367536763677367836793680368136823683368436853686368736883689369036913692369336943695369636973698369937003701370237033704370537063707370837093710371137123713371437153716371737183719372037213722372337243725372637273728372937303731373237333734373537363737373837393740374137423743374437453746374737483749375037513752375337543755375637573758375937603761376237633764376537663767376837693770377137723773377437753776377
73778377937803781378237833784378537863787378837893790379137923793379437953796379737983799380038013802380338043805380638073808380938103811381238133814381538163817381838193820382138223823382438253826382738283829383038313832383338343835383638373838383938403841384238433844384538463847384838493850385138523853385438553856385738583859386038613862386338643865386638673868386938703871387238733874387538763877387838793880388138823883388438853886388738883889389038913892389338943895389638973898389939003901390239033904390539063907390839093910391139123913391439153916391739183919392039213922392339243925392639273928392939303931393239333934393539363937393839393940394139423943394439453946394739483949395039513952395339543955395639573958395939603961396239633964396539663967396839693970397139723973397439753976397739783979398039813982398339843985398639873988398939903991399239933994399539963997399839994000400140024003400440054006400740084009401040114012401340144015401640174018401940204021402240234024402540264027402840294030403140324033403440354036403740384039404040414042404340444045404640474048404940504051405240534054405540564057405840594060406140624063406440654066406740684069407040714072407340744075407640774078407940804081408240834084408540864087408840894090409140924093409440954096409740984099410041014102410341044105410641074108410941104111411241134114411541164117411841194120412141224123412441254126412741284129413041314132413341344135413641374138413941404141414241434144414541464147414841494150415141524153415441554156415741584159416041614162416341644165416641674168416941704171417241734174417541764177417841794180418141824183418441854186418741884189419041914192419341944195419641974198419942004201420242034204420542064207420842094210421142124213421442154216421742184219422042214222422342244225422642274228422942304231423242334234423542364237423842394240424142424243424442454246424742484249425042514252425342544255425642574258425942604261426242634264426542664267426842694270427142724273427442754276427
74278427942804281428242834284428542864287428842894290429142924293429442954296429742984299430043014302430343044305430643074308430943104311431243134314431543164317431843194320432143224323432443254326432743284329433043314332433343344335433643374338433943404341434243434344434543464347434843494350
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from __future__ import annotations
  15. import contextlib
  16. import functools
  17. import json
  18. import math
  19. import os
  20. import re
  21. import shutil
  22. import warnings
  23. from collections import OrderedDict
  24. from contextlib import contextmanager
  25. from functools import partial
  26. from types import MethodType
  27. from typing import Any, Callable, Union
  28. import torch
  29. import torch.utils.hooks as hooks
  30. from huggingface_hub import split_torch_state_dict_into_shards
  31. from accelerate.utils.dataclasses import FP8BackendType
  32. from .big_modeling import _attach_context_parallel_hooks
  33. from .checkpointing import load_accelerator_state, load_custom_state, save_accelerator_state, save_custom_state
  34. from .data_loader import DataLoaderDispatcher, prepare_data_loader, skip_first_batches
  35. from .logging import get_logger
  36. from .optimizer import AcceleratedOptimizer
  37. from .parallelism_config import ParallelismConfig
  38. from .scheduler import AcceleratedScheduler
  39. from .state import AcceleratorState, GradientState, PartialState
  40. from .tracking import LOGGER_TYPE_TO_CLASS, GeneralTracker, filter_trackers
  41. from .utils import (
  42. MODEL_NAME,
  43. SAFE_WEIGHTS_INDEX_NAME,
  44. SAFE_WEIGHTS_NAME,
  45. SAFE_WEIGHTS_PATTERN_NAME,
  46. WEIGHTS_INDEX_NAME,
  47. WEIGHTS_NAME,
  48. WEIGHTS_PATTERN_NAME,
  49. AORecipeKwargs,
  50. AutocastKwargs,
  51. DataLoaderConfiguration,
  52. DeepSpeedPlugin,
  53. DistributedDataParallelKwargs,
  54. DistributedType,
  55. DynamoBackend,
  56. FP8RecipeKwargs,
  57. FullyShardedDataParallelPlugin,
  58. GradientAccumulationPlugin,
  59. GradScalerKwargs,
  60. InitProcessGroupKwargs,
  61. KwargsHandler,
  62. LoggerType,
  63. MegatronLMPlugin,
  64. MSAMPRecipeKwargs,
  65. PrecisionType,
  66. ProfileKwargs,
  67. ProjectConfiguration,
  68. RNGType,
  69. TERecipeKwargs,
  70. TorchDynamoPlugin,
  71. TorchTensorParallelPlugin,
  72. apply_fp8_autowrap,
  73. check_os_kernel,
  74. clean_state_dict_for_safetensors,
  75. compare_versions,
  76. convert_model,
  77. convert_model_to_fp8_ao,
  78. convert_outputs_to_fp32,
  79. ensure_weights_retied,
  80. extract_model_from_parallel,
  81. fsdp2_apply_ac,
  82. fsdp2_canonicalize_names,
  83. fsdp2_prepare_model,
  84. fsdp2_switch_optimizer_parameters,
  85. gather,
  86. gather_object,
  87. get_fsdp2_grad_scaler,
  88. get_grad_scaler,
  89. get_mixed_precision_context_manager,
  90. get_pretty_name,
  91. has_offloaded_params,
  92. is_bf16_available,
  93. is_bitsandbytes_multi_backend_available,
  94. is_deepspeed_available,
  95. is_ipex_available,
  96. is_lomo_available,
  97. is_megatron_lm_available,
  98. is_mlu_available,
  99. is_msamp_available,
  100. is_musa_available,
  101. is_npu_available,
  102. is_torch_version,
  103. is_torch_xla_available,
  104. is_torchao_available,
  105. is_transformer_engine_available,
  106. is_xpu_available,
  107. load_fsdp_model,
  108. load_fsdp_optimizer,
  109. model_has_dtensor,
  110. pad_across_processes,
  111. parse_choice_from_env,
  112. recursively_apply,
  113. reduce,
  114. release_memory,
  115. save,
  116. save_fsdp_model,
  117. save_fsdp_optimizer,
  118. wait_for_everyone,
  119. )
  120. from .utils.constants import (
  121. FSDP2_PYTORCH_VERSION,
  122. FSDP_PYTORCH_VERSION,
  123. PROFILE_PATTERN_NAME,
  124. SCALER_NAME,
  125. )
  126. from .utils.modeling import get_state_dict_offloaded_model
  127. from .utils.other import compile_regions, compile_regions_deepspeed, is_compiled_module
  128. if is_deepspeed_available():
  129. from .utils import (
  130. DeepSpeedEngineWrapper,
  131. DeepSpeedOptimizerWrapper,
  132. DeepSpeedSchedulerWrapper,
  133. DummyOptim,
  134. DummyScheduler,
  135. map_pytorch_optim_to_deepspeed,
  136. )
  137. if is_megatron_lm_available():
  138. from .utils import (
  139. MegatronEngine,
  140. MegatronLMDummyDataLoader,
  141. MegatronLMDummyScheduler,
  142. MegatronLMOptimizerWrapper,
  143. MegatronLMSchedulerWrapper,
  144. megatron_lm_initialize,
  145. megatron_lm_prepare_data_loader,
  146. megatron_lm_prepare_model_optimizer_scheduler,
  147. )
  148. from torch.distributed.algorithms.join import Join
  149. if is_torch_xla_available():
  150. import torch_xla.core.xla_model as xm
  151. import torch_xla.distributed.xla_multiprocessing as xmp
  152. if is_npu_available(check_device=False):
  153. import torch_npu # noqa: F401
  154. try:
  155. from torch.optim.lr_scheduler import LRScheduler
  156. except ImportError:
  157. from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
  158. logger = get_logger(__name__)
  159. # Sentinel values for defaults
  160. _split_batches = object()
  161. _dispatch_batches = object()
  162. _even_batches = object()
  163. _use_seedable_sampler = object()
  164. class Accelerator:
  165. """
  166. Creates an instance of an accelerator for distributed training or mixed precision training.
  167. Args:
  168. device_placement (`bool`, *optional*, defaults to `True`):
  169. Whether or not the accelerator should put objects on device (tensors yielded by the dataloader, model,
  170. etc...).
  171. mixed_precision (`str`, *optional*):
  172. Whether or not to use mixed precision training. Choose from 'no','fp16','bf16' or 'fp8'. Will default to
  173. the value in the environment variable `ACCELERATE_MIXED_PRECISION`, which will use the default value in the
  174. accelerate config of the current system or the flag passed with the `accelerate.launch` command. 'fp8'
  175. requires the installation of transformers-engine.
  176. gradient_accumulation_steps (`int`, *optional*, default to 1):
  177. The number of steps that should pass before gradients are accumulated. A number > 1 should be combined with
  178. `Accelerator.accumulate`. If not passed, will default to the value in the environment variable
  179. `ACCELERATE_GRADIENT_ACCUMULATION_STEPS`. Can also be configured through a `GradientAccumulationPlugin`.
  180. cpu (`bool`, *optional*):
  181. Whether or not to force the script to execute on CPU. Will ignore GPU available if set to `True` and force
  182. the execution on one process only.
  183. dataloader_config (`DataLoaderConfiguration`, *optional*):
  184. A configuration for how the dataloaders should be handled in distributed scenarios.
  185. deepspeed_plugin ([`~utils.DeepSpeedPlugin`] or dict of `str`: [`~utils.DeepSpeedPlugin`], *optional*):
  186. Tweak your DeepSpeed related args using this argument. This argument is optional and can be configured
  187. directly using *accelerate config*. If using multiple plugins, use the configured `key` property of each
  188. plugin to access them from `accelerator.state.get_deepspeed_plugin(key)`. Alias for `deepspeed_plugins`.
  189. fsdp_plugin ([`~utils.FullyShardedDataParallelPlugin`], *optional*):
  190. Tweak your FSDP related args using this argument. This argument is optional and can be configured directly
  191. using *accelerate config*
  192. torch_tp_plugin ([`~utils.TorchTensorParallelPlugin`], *optional*):
  193. Deprecated: use `parallelism_config` with `tp_size` instead.
  194. megatron_lm_plugin ([`~utils.MegatronLMPlugin`], *optional*):
  195. Tweak your MegatronLM related args using this argument. This argument is optional and can be configured
  196. directly using *accelerate config*
  197. rng_types (list of `str` or [`~utils.RNGType`]):
  198. The list of random number generators to synchronize at the beginning of each iteration in your prepared
  199. dataloaders. Should be one or several of:
  200. - `"torch"`: the base torch random number generator
  201. - `"cuda"`: the CUDA random number generator (GPU only)
  202. - `"xla"`: the XLA random number generator (TPU only)
  203. - `"generator"`: the `torch.Generator` of the sampler (or batch sampler if there is no sampler in your
  204. dataloader) or of the iterable dataset (if it exists) if the underlying dataset is of that type.
  205. Will default to `["torch"]` for PyTorch versions <=1.5.1 and `["generator"]` for PyTorch versions >= 1.6.
  206. log_with (list of `str`, [`~utils.LoggerType`] or [`~tracking.GeneralTracker`], *optional*):
  207. A list of loggers to be setup for experiment tracking. Should be one or several of:
  208. - `"all"`
  209. - `"tensorboard"`
  210. - `"wandb"`
  211. - `"trackio"`
  212. - `"aim"`
  213. - `"comet_ml"`
  214. - `"mlflow"`
  215. - `"dvclive"`
  216. - `"swanlab"`
  217. If `"all"` is selected, will pick up all available trackers in the environment and initialize them. Can
  218. also accept implementations of `GeneralTracker` for custom trackers, and can be combined with `"all"`.
  219. project_config ([`~utils.ProjectConfiguration`], *optional*):
  220. A configuration for how saving the state can be handled.
  221. project_dir (`str`, `os.PathLike`, *optional*):
  222. A path to a directory for storing data such as logs of locally-compatible loggers and potentially saved
  223. checkpoints.
  224. step_scheduler_with_optimizer (`bool`, *optional*, defaults to `True`):
  225. Set `True` if the learning rate scheduler is stepped at the same time as the optimizer, `False` if only
  226. done under certain circumstances (at the end of each epoch, for instance).
  227. kwargs_handlers (list of [`~utils.KwargsHandler`], *optional*)
  228. A list of [`~utils.KwargsHandler`] to customize how the objects related to distributed training, profiling
  229. or mixed precision are created. See [kwargs](kwargs) for more information.
  230. dynamo_backend (`str` or [`~utils.DynamoBackend`], *optional*, defaults to `"no"`):
  231. Set to one of the possible dynamo backends to optimize your training with torch dynamo.
  232. dynamo_plugin ([`~utils.TorchDynamoPlugin`], *optional*):
  233. A configuration for how torch dynamo should be handled, if more tweaking than just the `backend` or `mode`
  234. is needed.
  235. gradient_accumulation_plugin ([`~utils.GradientAccumulationPlugin`], *optional*):
  236. A configuration for how gradient accumulation should be handled, if more tweaking than just the
  237. `gradient_accumulation_steps` is needed.
  238. **Available attributes:**
  239. - **device** (`torch.device`) -- The device to use.
  240. - **distributed_type** ([`~utils.DistributedType`]) -- The distributed training configuration.
  241. - **local_process_index** (`int`) -- The process index on the current machine.
  242. - **mixed_precision** (`str`) -- The configured mixed precision mode.
  243. - **num_processes** (`int`) -- The total number of processes used for training.
  244. - **optimizer_step_was_skipped** (`bool`) -- Whether or not the optimizer update was skipped (because of
  245. gradient overflow in mixed precision), in which
  246. case the learning rate should not be changed.
  247. - **process_index** (`int`) -- The overall index of the current process among all processes.
  248. - **state** ([`~state.AcceleratorState`]) -- The distributed setup state.
  249. - **sync_gradients** (`bool`) -- Whether the gradients are currently being synced across all processes.
  250. - **use_distributed** (`bool`) -- Whether the current configuration is for distributed training.
  251. """
  252. def __init__(
  253. self,
  254. device_placement: bool = True,
  255. split_batches: bool = _split_batches,
  256. mixed_precision: PrecisionType | str | None = None,
  257. gradient_accumulation_steps: int = 1,
  258. cpu: bool = False,
  259. dataloader_config: DataLoaderConfiguration | None = None,
  260. deepspeed_plugin: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
  261. fsdp_plugin: FullyShardedDataParallelPlugin | None = None,
  262. torch_tp_plugin: TorchTensorParallelPlugin | None = None, # Deprecate later, warning in `post_init`
  263. megatron_lm_plugin: MegatronLMPlugin | None = None,
  264. rng_types: list[str | RNGType] | None = None,
  265. log_with: str | LoggerType | GeneralTracker | list[str | LoggerType | GeneralTracker] | None = None,
  266. project_dir: str | os.PathLike | None = None,
  267. project_config: ProjectConfiguration | None = None,
  268. gradient_accumulation_plugin: GradientAccumulationPlugin | None = None,
  269. step_scheduler_with_optimizer: bool = True,
  270. kwargs_handlers: list[KwargsHandler] | None = None,
  271. dynamo_backend: DynamoBackend | str | None = None,
  272. dynamo_plugin: TorchDynamoPlugin | None = None,
  273. deepspeed_plugins: DeepSpeedPlugin | dict[str, DeepSpeedPlugin] | None = None,
  274. parallelism_config: ParallelismConfig | None = None,
  275. ):
  276. self.trackers = []
  277. if project_config is not None:
  278. self.project_configuration = project_config
  279. else:
  280. self.project_configuration = ProjectConfiguration(project_dir=project_dir)
  281. if project_dir is not None and self.project_dir is None:
  282. self.project_configuration.set_directories(project_dir)
  283. if mixed_precision is not None:
  284. mixed_precision = str(mixed_precision)
  285. if mixed_precision not in PrecisionType:
  286. raise ValueError(
  287. f"Unknown mixed_precision mode: {mixed_precision}. Choose between {PrecisionType.list()}"
  288. )
  289. if torch_tp_plugin is not None:
  290. warnings.warn(
  291. "`TorchTensorParallelPlugin` is deprecated and will be removed in a future version of Accelerate. "
  292. "Please use the `ParallelismConfig` with `tp_size` instead.",
  293. FutureWarning,
  294. )
  295. if dynamo_plugin is not None and dynamo_backend is not None:
  296. raise ValueError("You cannot pass in both `dynamo_plugin` and `dynamo_backend`, please only pass in one.")
  297. if dynamo_backend is not None:
  298. dynamo_plugin = TorchDynamoPlugin(backend=dynamo_backend)
  299. elif dynamo_plugin is None:
  300. dynamo_plugin = TorchDynamoPlugin()
  301. if deepspeed_plugins is not None and deepspeed_plugin is not None:
  302. raise ValueError("You cannot pass in both `deepspeed_plugins` and `deepspeed_plugin`.")
  303. elif deepspeed_plugin is not None:
  304. deepspeed_plugins = deepspeed_plugin
  305. if deepspeed_plugins is None:
  306. # First check if we're creating another `Accelerator` w/o setting `deepspeed_plugin`
  307. if (
  308. AcceleratorState._shared_state != {}
  309. and AcceleratorState().distributed_type == DistributedType.DEEPSPEED
  310. ):
  311. deepspeed_plugins = AcceleratorState().deepspeed_plugins
  312. else:
  313. # init from env variables
  314. deepspeed_plugins = (
  315. DeepSpeedPlugin()
  316. if os.environ.get("ACCELERATE_USE_DEEPSPEED", "false").lower() == "true"
  317. else None
  318. )
  319. else:
  320. # If we're creating a second `Accelerator`, users shouldn't be passing in a `deepspeed_plugin`
  321. if (
  322. AcceleratorState._shared_state != {}
  323. and AcceleratorState().distributed_type == DistributedType.DEEPSPEED
  324. and AcceleratorState().deepspeed_plugins is not None
  325. ):
  326. raise NotImplementedError(
  327. "You cannot pass in a `deepspeed_plugin` when creating a second `Accelerator`. "
  328. "Please make sure the first `Accelerator` is initialized with all the plugins you want to use."
  329. )
  330. if isinstance(deepspeed_plugins, dict):
  331. for plugin in deepspeed_plugins.values():
  332. if not isinstance(plugin, DeepSpeedPlugin):
  333. raise TypeError("`deepspeed_plugin` must be a DeepSpeedPlugin object.")
  334. if deepspeed_plugins is not None:
  335. os.environ["ACCELERATE_USE_DEEPSPEED"] = "true" # use DeepSpeed if plugin is provided
  336. if not is_deepspeed_available():
  337. raise ImportError("DeepSpeed is not installed => run `pip install deepspeed` or build it from source.")
  338. if is_mlu_available():
  339. if compare_versions("deepspeed", "<", "0.15.2"):
  340. raise ImportError("DeepSpeed MLU version must be >= 0.15.2. Please update DeepSpeed.")
  341. elif is_musa_available():
  342. if compare_versions("deepspeed", "<", "0.14.3"):
  343. raise ImportError("DeepSpeed MUSA version must be >= 0.14.3. Please update DeepSpeed.")
  344. elif compare_versions("deepspeed", "<", "0.9.3"):
  345. raise ImportError("DeepSpeed version must be >= 0.9.3. Please update DeepSpeed.")
  346. self.deepspeed_engine_wrapped = None
  347. if os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true" or isinstance(
  348. fsdp_plugin, FullyShardedDataParallelPlugin
  349. ):
  350. if not is_torch_version(">=", FSDP_PYTORCH_VERSION):
  351. raise ValueError(f"FSDP requires PyTorch >= {FSDP_PYTORCH_VERSION}")
  352. if fsdp_plugin is None: # init from env variables
  353. fsdp_plugin = (
  354. FullyShardedDataParallelPlugin()
  355. if os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
  356. else None
  357. )
  358. else:
  359. if not isinstance(fsdp_plugin, FullyShardedDataParallelPlugin):
  360. raise TypeError("`fsdp_plugin` must be a FullyShardedDataParallelPlugin object.")
  361. os.environ["ACCELERATE_USE_FSDP"] = "true" # use FSDP if plugin is provided
  362. if fsdp_plugin is not None and fsdp_plugin.fsdp_version == 2:
  363. if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
  364. raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
  365. if megatron_lm_plugin is None: # init from env variables
  366. megatron_lm_plugin = (
  367. MegatronLMPlugin() if os.environ.get("ACCELERATE_USE_MEGATRON_LM", "false").lower() == "true" else None
  368. )
  369. else:
  370. if not isinstance(megatron_lm_plugin, MegatronLMPlugin):
  371. raise TypeError("`megatron_lm_plugin` must be a MegatronLMPlugin object.")
  372. os.environ["ACCELERATE_USE_MEGATRON_LM"] = "true" # use MegatronLM if plugin is provided
  373. if megatron_lm_plugin:
  374. if not is_megatron_lm_available():
  375. raise ImportError("Megatron is not installed. please build it from source.")
  376. # Kwargs handlers
  377. self.ddp_handler = None
  378. self.scaler_handler = None
  379. self.init_handler = None
  380. self.fp8_recipe_handler = None
  381. self.ao_recipe_handler = None
  382. self.te_recipe_handler = None
  383. self.msamp_recipe_handler = None
  384. self.autocast_handler = None
  385. self.profile_handler = None
  386. self.has_lomo_optimizer = False
  387. found_handlers = set()
  388. handler_class_to_attr = {
  389. DistributedDataParallelKwargs: "ddp_handler",
  390. GradScalerKwargs: "scaler_handler",
  391. InitProcessGroupKwargs: "init_handler",
  392. FP8RecipeKwargs: "fp8_recipe_handler",
  393. AutocastKwargs: "autocast_handler",
  394. ProfileKwargs: "profile_handler",
  395. AORecipeKwargs: "ao_recipe_handler",
  396. TERecipeKwargs: "te_recipe_handler",
  397. MSAMPRecipeKwargs: "msamp_recipe_handler",
  398. }
  399. self.has_fp8_handler = False
  400. if kwargs_handlers is not None:
  401. for handler in kwargs_handlers:
  402. assert isinstance(handler, KwargsHandler), (
  403. f"Unsupported kwargs handler passed: {handler}, must be one that inherits `accelerate.utils.KwargsHandler`."
  404. )
  405. # Add the handler class to the set of found handlers
  406. if handler.__class__ in found_handlers:
  407. raise ValueError(f"You can only pass one {handler.__class__} in `kwargs_handlers`.")
  408. found_handlers.add(handler.__class__)
  409. handler_attr = handler_class_to_attr[handler.__class__]
  410. setattr(self, handler_attr, handler)
  411. if "recipe_handler" in handler_attr and not self.has_fp8_handler:
  412. self.has_fp8_handler = True
  413. if parallelism_config is None:
  414. # TODO: Remove after deprecating tp_plugin
  415. if torch_tp_plugin is not None:
  416. parallelism_config = ParallelismConfig(tp_size=torch_tp_plugin.tp_size)
  417. elif os.environ.get("ACCELERATE_USE_PARALLELISM_CONFIG", "false").lower() == "true":
  418. parallelism_config = ParallelismConfig()
  419. kwargs = self.init_handler.to_kwargs() if self.init_handler is not None else {}
  420. self.state = AcceleratorState(
  421. mixed_precision=mixed_precision,
  422. cpu=cpu,
  423. dynamo_plugin=dynamo_plugin,
  424. deepspeed_plugin=deepspeed_plugins,
  425. fsdp_plugin=fsdp_plugin,
  426. megatron_lm_plugin=megatron_lm_plugin,
  427. parallelism_config=parallelism_config,
  428. _from_accelerator=True,
  429. **kwargs,
  430. )
  431. if self.parallelism_config:
  432. self.state.device_mesh = parallelism_config.get_device_mesh(self.device.type)
  433. self.parallelism_config._validate_accelerator(self)
  434. self.fp8_enabled = self.state.mixed_precision == "fp8" or mixed_precision == "fp8"
  435. # Check for automatic FP8 recipe creation
  436. if self.fp8_enabled and not self.has_fp8_handler:
  437. if self.fp8_backend == FP8BackendType.AO:
  438. self.ao_recipe_handler = AORecipeKwargs()
  439. elif self.fp8_backend == FP8BackendType.TE:
  440. self.te_recipe_handler = TERecipeKwargs()
  441. elif self.fp8_backend == FP8BackendType.MSAMP:
  442. self.msamp_recipe_handler = MSAMPRecipeKwargs()
  443. elif self.fp8_backend == FP8BackendType.NO:
  444. # Prioritize AO -> TE -> MSAMP
  445. if is_torchao_available():
  446. logger.info("Found `torchao` installed, using it for FP8 training.")
  447. self.ao_recipe_handler = AORecipeKwargs()
  448. elif is_transformer_engine_available():
  449. logger.info("Found `transformer-engine` installed, using it for FP8 training.")
  450. self.te_recipe_handler = TERecipeKwargs()
  451. elif is_msamp_available():
  452. logger.info("Found `msamp` installed, using it for FP8 training.")
  453. self.msamp_recipe_handler = MSAMPRecipeKwargs()
  454. else:
  455. raise ImportError(
  456. "Tried to train with `fp8` and auto-detect backend, but no FP8-compatible backend was installed. "
  457. "Valid backends are: `torchao`, `transformer-engine`, and `msamp`."
  458. )
  459. self.has_fp8_handler = True
  460. self.delayed_fp8_autocast = False
  461. if self.has_fp8_handler:
  462. # We already check if FP8 is available during `self.state`
  463. if not self.fp8_enabled and (
  464. self.distributed_type not in (DistributedType.FSDP, DistributedType.DEEPSPEED)
  465. ):
  466. raise ValueError("Passing in an FP8 configuration requires setting `mixed_precision='fp8'`.")
  467. self.delayed_fp8_autocast = self.fp8_backend == "TE" and self.distributed_type in (
  468. DistributedType.MULTI_GPU,
  469. DistributedType.FSDP,
  470. )
  471. # TODO: S1ro - this is probably gonna be a problem with other fp8 backends too
  472. if (
  473. self.fp8_backend == FP8BackendType.AO
  474. and self.state.distributed_type == DistributedType.FSDP
  475. and self.state.fsdp_plugin.cpu_ram_efficient_loading
  476. ):
  477. raise ValueError(
  478. "torchao with FSDP2 and cpu_ram_efficient_loading is not supported, setting `cpu_ram_efficient_loading` to False will fix the issue and work as intended."
  479. )
  480. trackers = filter_trackers(log_with, self.logging_dir)
  481. if len(trackers) < 1 and log_with is not None:
  482. warnings.warn(f"`log_with={log_with}` was passed but no supported trackers are currently installed.")
  483. self.log_with = trackers
  484. if (
  485. (mixed_precision != "bf16")
  486. and getattr(self.state, "downcast_bfloat", False)
  487. and (self.state.distributedType != DistributedType.XLA)
  488. ):
  489. raise ValueError("Can only use `downcast_bf16` when using `mixed_precision='bf16'` and on a TPU")
  490. if gradient_accumulation_plugin is not None:
  491. if gradient_accumulation_steps != 1:
  492. raise ValueError(
  493. "You can only pass one of `gradient_accumulation_steps` and `gradient_accumulation_plugin`. Please only pass in the created `GradientAccumulationPlugin` object."
  494. )
  495. else:
  496. gradient_accumulation_steps = int(
  497. parse_choice_from_env("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", gradient_accumulation_steps)
  498. )
  499. gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=gradient_accumulation_steps)
  500. # If using DeepSpeed, update gradient accumulation steps from the DeepSpeed plugin
  501. self.gradient_state = GradientState(
  502. gradient_accumulation_plugin=gradient_accumulation_plugin,
  503. )
  504. self.device_placement = device_placement
  505. if dataloader_config is None:
  506. dataloader_config = DataLoaderConfiguration()
  507. self.dataloader_config = dataloader_config
  508. self.step_scheduler_with_optimizer = step_scheduler_with_optimizer
  509. # Mixed precision attributes
  510. self.scaler = None
  511. self.native_amp = False
  512. if (
  513. self.state.mixed_precision == "fp16"
  514. and self.device.type != "cpu"
  515. and self.distributed_type not in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM)
  516. ):
  517. self.native_amp = True
  518. supported_device = ("xpu", "cuda", "npu", "xla", "mlu", "musa", "hpu", "sdaa", "mps")
  519. if self.device.type not in supported_device or is_torch_xla_available(check_is_tpu=True):
  520. raise ValueError(
  521. f"fp16 mixed precision requires a device in {supported_device} (not {self.device.type!r})."
  522. )
  523. if self.device.type == "mps" and not is_torch_version(">=", "2.5.0"):
  524. raise ValueError("fp16 mixed precision with MPS device requires a Pytorch >= 2.5.0")
  525. kwargs = self.scaler_handler.to_kwargs() if self.scaler_handler is not None else {}
  526. # FSDP2 doesn't use ShardedGradScaler, don't want to modify `get_grad_scaler`, rather create a simple utility
  527. if self.is_fsdp2:
  528. self.scaler = get_fsdp2_grad_scaler(device=self.device.type, **kwargs)
  529. else:
  530. self.scaler = get_grad_scaler(self.distributed_type, **kwargs)
  531. elif self.state.mixed_precision == "bf16" and self.distributed_type not in (
  532. DistributedType.DEEPSPEED,
  533. DistributedType.MEGATRON_LM,
  534. ):
  535. if self.device.type in ["cpu", "xpu", "hpu"]:
  536. self.native_amp = True
  537. else:
  538. self.native_amp = is_bf16_available(True)
  539. if not self.native_amp and not is_torch_xla_available():
  540. raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")
  541. if self.native_amp and self.device.type == "mps" and not is_torch_version(">=", "2.6.0"):
  542. raise ValueError("bf16 mixed precision with MPS device requires a Pytorch >= 2.6.0")
  543. # for DeepSpeed, self.state.mixed_precision is always "bf16",
  544. # see https://github.com/huggingface/accelerate/blob/main/src/accelerate/state.py#L968 and
  545. # https://github.com/huggingface/accelerate/blob/main/src/accelerate/utils/dataclasses.py#L1263.
  546. elif self.fp8_enabled:
  547. # We always enable `native_amp` for FP8
  548. self.native_amp = True
  549. if self.fp8_backend == FP8BackendType.MSAMP:
  550. if self.distributed_type == DistributedType.FSDP:
  551. raise NotImplementedError(
  552. "`accelerate` + `MS-AMP` + `FSDP` is not supported at this time. "
  553. "Please consider using deepspeed, which is supported."
  554. )
  555. elif self.distributed_type != DistributedType.DEEPSPEED:
  556. # MS-AMP requires `GradScaler` even with bf16 autocast w/ single GPU or DDP:
  557. self.scaler = get_grad_scaler(**kwargs)
  558. # Start of internal step tracking
  559. self.step = 0
  560. # Internal references to the training objects
  561. self._optimizers = []
  562. self._models = []
  563. self._schedulers = []
  564. self._dataloaders = []
  565. self._custom_objects = []
  566. # Hooks
  567. self._load_model_state_pre_hook = OrderedDict()
  568. self._save_model_state_pre_hook = OrderedDict()
  569. # RNG Types
  570. self.rng_types = rng_types
  571. if self.rng_types is None:
  572. self.rng_types = ["generator"]
  573. # Set a flag tensor for early stopping and other breakpoints
  574. self.flag_tensor = None
  575. check_os_kernel()
  @property
  def deepspeed_plugin(self):
      """
      The `DeepSpeedPlugin` that is currently active, read from `self.state`.

      When several plugins are registered, the first one is active by default; call
      `accelerator.state.select_deepspeed_plugin(key)` to activate a different one. Returns `None` when
      DeepSpeed is not enabled.
      """
      active_plugin = self.state.deepspeed_plugin
      return active_plugin

  @property
  def use_distributed(self):
      """
      Whether the Accelerator is configured for distributed training (mirrors `self.state.use_distributed`).
      """
      is_distributed = self.state.use_distributed
      return is_distributed
  @property
  def multi_device(self):
      """
      Whether training runs distributed over several devices of one of the multi-device backends
      (GPU, MLU, SDAA, MUSA, NPU, XPU or HPU).

      Evaluates to `self.use_distributed` itself when that is falsy; otherwise to whether
      `self.distributed_type` is one of the multi-device distributed types.
      """
      multi_device_types = (
          DistributedType.MULTI_GPU,
          DistributedType.MULTI_MLU,
          DistributedType.MULTI_SDAA,
          DistributedType.MULTI_MUSA,
          DistributedType.MULTI_NPU,
          DistributedType.MULTI_XPU,
          DistributedType.MULTI_HPU,
      )
      return self.use_distributed and self.distributed_type in multi_device_types
  @property
  def distributed_type(self):
      """The `DistributedType` recorded on the shared state."""
      state = self.state
      return state.distributed_type

  @property
  def num_processes(self):
      """Total number of processes, as recorded on the shared state."""
      state = self.state
      return state.num_processes

  @property
  def process_index(self):
      """Global index of the current process, as recorded on the shared state."""
      state = self.state
      return state.process_index

  @property
  def local_process_index(self):
      """Index of the current process on its own machine, as recorded on the shared state."""
      state = self.state
      return state.local_process_index

  @property
  def device(self):
      """Device assigned to the current process, as recorded on the shared state."""
      state = self.state
      return state.device
  @property
  def split_batches(self):
      """`split_batches` flag of the dataloader configuration."""
      config = self.dataloader_config
      return config.split_batches

  @property
  def dispatch_batches(self):
      """`dispatch_batches` setting of the dataloader configuration."""
      config = self.dataloader_config
      return config.dispatch_batches

  @property
  def even_batches(self):
      """`even_batches` flag of the dataloader configuration (writable)."""
      config = self.dataloader_config
      return config.even_batches

  @even_batches.setter
  def even_batches(self, value: bool):
      config = self.dataloader_config
      config.even_batches = value

  @property
  def use_seedable_sampler(self):
      """`use_seedable_sampler` flag of the dataloader configuration."""
      config = self.dataloader_config
      return config.use_seedable_sampler

  @property
  def non_blocking(self):
      """`non_blocking` flag of the dataloader configuration."""
      config = self.dataloader_config
      return config.non_blocking

  @property
  def use_stateful_dataloader(self):
      """
      `use_stateful_dataloader` flag of the dataloader configuration, or `False` when the configuration
      object does not define that attribute.
      """
      return getattr(self.dataloader_config, "use_stateful_dataloader", False)
  @property
  def project_dir(self):
      """Project directory taken from the project configuration."""
      configuration = self.project_configuration
      return configuration.project_dir

  @property
  def logging_dir(self):
      """Logging directory taken from the project configuration."""
      configuration = self.project_configuration
      return configuration.logging_dir

  @property
  def save_iteration(self):
      """Current save iteration (`iteration`) of the project configuration."""
      configuration = self.project_configuration
      return configuration.iteration
  @property
  def is_main_process(self):
      """`True` on exactly one process across the whole run."""
      state = self.state
      return state.is_main_process

  @property
  def is_local_main_process(self):
      """`True` on exactly one process per server."""
      state = self.state
      return state.is_local_main_process

  @property
  def is_last_process(self):
      """`True` when this process has the highest global index (`num_processes - 1`)."""
      last_index = self.num_processes - 1
      return self.process_index == last_index
  @property
  def mixed_precision(self):
      """Mixed-precision mode recorded on the shared state."""
      state = self.state
      return state.mixed_precision

  @property
  def is_fsdp2(self):
      """Whether the shared state reports FSDP2 (mirrors `self.state.is_fsdp2`)."""
      state = self.state
      return state.is_fsdp2

  @property
  def is_composable_parallelism_enabled(self):
      """Whether composable parallelism is enabled; currently identical to `is_fsdp2`."""
      enabled = self.is_fsdp2
      return enabled

  @property
  def parallelism_config(self) -> Union[ParallelismConfig, None]:
      """The `ParallelismConfig` stored on the shared state, or `None` if none was set."""
      state = self.state
      return state.parallelism_config

  @property
  def torch_device_mesh(self):
      """Device mesh stored on the shared state."""
      state = self.state
      return state.device_mesh
  @property
  def should_save_model(self):
      """
      Whether this process should take part in saving the model.

      Without a `parallelism_config`, this falls back to `self.state.is_local_main_process`. With one, every
      process currently returns `True` (see the TODO below).
      """
      if self.parallelism_config is None:
          # Not expected in practice, but keep a sensible fallback.
          return self.state.is_local_main_process
      # TODO(S1ro): temporary behavior. The intended rule is to save only on ranks whose local rank is 0 along
      # every *enabled* non-model-shard mesh dimension (`dp_replicate` and `cp`). That rule is disabled
      # because `save_safe_file` is slow when not all processes save (cause still under investigation;
      # the original note was cut off here).
      return True
  @property
  def tensor_parallel_rank(self) -> int:
      """
      Local rank of this process along the `"tp"` (tensor-parallel) mesh dimension.

      When a `parallelism_config` is set but tensor parallelism is not enabled, 0 is returned because all
      ranks are assumed to be the same.

      Raises:
          RuntimeError: if no `parallelism_config` is set.
      """
      config = self.parallelism_config
      if not config:
          raise RuntimeError("Tensor parallelism is not configured. Set `parallelism_config` first.")
      if not config.tp_enabled:
          return 0
      return self.torch_device_mesh.get_local_rank("tp")

  @property
  def pipeline_parallel_rank(self) -> int:
      """
      Placeholder for pipeline parallelism, which Accelerate does not support yet.

      Raises:
          NotImplementedError: always.
      """
      raise NotImplementedError("Pipeline parallelism is currently not supported in Accelerate.")

  @property
  def context_parallel_rank(self) -> int:
      """
      Placeholder for the context-parallel rank, which is not exposed yet.

      Raises:
          NotImplementedError: always.
      """
      raise NotImplementedError("Context parallelism is currently not supported in Accelerate.")

  @property
  def data_parallel_rank(self) -> int:
      """
      Local rank of this process along the `"dp_replicate"` (replicate-based data-parallel) mesh dimension.

      When a `parallelism_config` is set but replicate-based data parallelism is not enabled, 0 is returned
      because all ranks are assumed to be the same.

      Raises:
          RuntimeError: if no `parallelism_config` is set.
      """
      config = self.parallelism_config
      if not config:
          raise RuntimeError("Data parallelism is not configured. Set `parallelism_config` first.")
      if not config.dp_replicate_enabled:
          return 0
      return self.torch_device_mesh.get_local_rank("dp_replicate")

  @property
  def data_parallel_shard_rank(self) -> int:
      """
      Local rank of this process along the `"dp_shard"` (shard-based data-parallel) mesh dimension.

      When a `parallelism_config` is set but shard-based data parallelism is not enabled, 0 is returned
      because all ranks are assumed to be the same.

      Raises:
          RuntimeError: if no `parallelism_config` is set.
      """
      config = self.parallelism_config
      if not config:
          raise RuntimeError("Shard-based data parallelism is not configured. Set `parallelism_config` first.")
      if not config.dp_shard_enabled:
          return 0
      return self.torch_device_mesh.get_local_rank("dp_shard")
  @contextmanager
  def split_between_processes(self, inputs: list | tuple | dict | torch.Tensor, apply_padding: bool = False):
      """
      Split `inputs` across `self.num_processes` so that each process receives its own slice. Handy for
      distributed inference, e.g. over a list of prompts. Delegates to `PartialState().split_between_processes`.

      When `inputs` is a `dict`, every value must have the same number of elements.

      Args:
          inputs (`list`, `tuple`, `torch.Tensor`, or `dict` of `list`/`tuple`/`torch.Tensor`):
              The input to split between processes.
          apply_padding (`bool`, `optional`, defaults to `False`):
              Whether to pad by repeating the last element so that every process gets the same number of
              elements. Useful before calling `Accelerator.gather()` on the outputs, or when there are fewer
              inputs than processes. Remember to drop the padded elements afterwards.

      Example:

      ```python
      # Assume there are two processes
      from accelerate import Accelerator

      accelerator = Accelerator()
      with accelerator.split_between_processes(["A", "B", "C"]) as inputs:
          print(inputs)
      # Process 0
      ["A", "B"]
      # Process 1
      ["C"]

      with accelerator.split_between_processes(["A", "B", "C"], apply_padding=True) as inputs:
          print(inputs)
      # Process 0
      ["A", "B"]
      # Process 1
      ["C", "C"]
      ```
      """
      shared_state = PartialState()
      with shared_state.split_between_processes(inputs, apply_padding=apply_padding) as process_inputs:
          yield process_inputs
  769. def on_main_process(self, function: Callable[..., Any] | None = None):
  770. """
  771. A decorator that will run the decorated function on the main process only. Can also be called using the
  772. `PartialState` class.
  773. Args:
  774. function (`Callable`): The function to decorate.
  775. Example:
  776. ```python
  777. >>> from accelerate import Accelerator
  778. >>> accelerator = Accelerator()
  779. >>> @accelerator.on_main_process
  780. ... def print_something():
  781. ... print("This will be printed by process 0 only.")
  782. >>> print_something()
  783. "This will be printed by process 0 only"
  784. ```
  785. """
  786. # For times when the `Accelerator` object itself utilizes this decorator.
  787. if function is None:
  788. if "Accelerator." in self.__qualname__:
  789. function = self
  790. else:
  791. raise ValueError(
  792. "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
  793. )
  794. def _inner(*args, **kwargs):
  795. return PartialState().on_main_process(function)(*args, **kwargs)
  796. return _inner
  797. def on_local_main_process(self, function: Callable[..., Any] | None = None):
  798. """
  799. A decorator that will run the decorated function on the local main process only. Can also be called using the
  800. `PartialState` class.
  801. Args:
  802. function (`Callable`): The function to decorate.
  803. Example:
  804. ```python
  805. # Assume we have 2 servers with 4 processes each.
  806. from accelerate import Accelerator
  807. accelerator = Accelerator()
  808. @accelerator.on_local_main_process
  809. def print_something():
  810. print("This will be printed by process 0 only on each server.")
  811. print_something()
  812. # On server 1:
  813. "This will be printed by process 0 only"
  814. # On server 2:
  815. "This will be printed by process 0 only"
  816. ```
  817. """
  818. # For times when the `Accelerator` object itself utilizes this decorator.
  819. if function is None:
  820. if "Accelerator." in self.__qualname__:
  821. function = self
  822. else:
  823. raise ValueError(
  824. "The `on_local_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
  825. )
  826. def _inner(*args, **kwargs):
  827. return PartialState().on_local_main_process(function)(*args, **kwargs)
  828. return _inner
  829. def on_last_process(self, function: Callable[..., Any]):
  830. """
  831. A decorator that will run the decorated function on the last process only. Can also be called using the
  832. `PartialState` class.
  833. Args:
  834. function (`Callable`): The function to decorate.
  835. Example:
  836. ```python
  837. # Assume we have 4 processes.
  838. from accelerate import Accelerator
  839. accelerator = Accelerator()
  840. @accelerator.on_last_process
  841. def print_something():
  842. print(f"Printed on process {accelerator.process_index}")
  843. print_something()
  844. "Printed on process 3"
  845. ```
  846. """
  847. # For times when the `Accelerator` object itself utilizes this decorator.
  848. if function is None:
  849. if "Accelerator." in self.__qualname__:
  850. function = self
  851. else:
  852. raise ValueError(
  853. "The `on_last_process` decorator must be called with a function on an instantiated `Accelerator` object."
  854. )
  855. def _inner(*args, **kwargs):
  856. return PartialState().on_last_process(function)(*args, **kwargs)
  857. return _inner
  858. def on_process(self, function: Callable[..., Any] | None = None, process_index: int | None = None):
  859. """
  860. A decorator that will run the decorated function on a given process index only. Can also be called using the
  861. `PartialState` class.
  862. Args:
  863. function (`Callable`, `optional`):
  864. The function to decorate.
  865. process_index (`int`, `optional`):
  866. The index of the process on which to run the function.
  867. Example:
  868. ```python
  869. # Assume we have 4 processes.
  870. from accelerate import Accelerator
  871. accelerator = Accelerator()
  872. @accelerator.on_process(process_index=2)
  873. def print_something():
  874. print(f"Printed on process {accelerator.process_index}")
  875. print_something()
  876. "Printed on process 2"
  877. ```
  878. """
  879. # Initial construction of the decorator.
  880. if (self is not None) and (process_index is not None) and (function is None):
  881. return partial(self.on_process, process_index=process_index)
  882. # For times when the `Accelerator` object itself utilizes this decorator.
  883. if function is None:
  884. if "Accelerator." in self.__qualname__:
  885. function = self
  886. else:
  887. raise ValueError(
  888. "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
  889. )
  890. def _inner(*args, **kwargs):
  891. return PartialState().on_process(function, process_index)(*args, **kwargs)
  892. return _inner
  893. def on_local_process(self, function: Callable[..., Any] | None = None, local_process_index: int | None = None):
  894. """
  895. A decorator that will run the decorated function on a given local process index only. Can also be called using
  896. the `PartialState` class.
  897. Args:
  898. function (`Callable`, *optional*):
  899. The function to decorate.
  900. local_process_index (`int`, *optional*):
  901. The index of the local process on which to run the function.
  902. Example:
  903. ```python
  904. # Assume we have 2 servers with 4 processes each.
  905. from accelerate import Accelerator
  906. accelerator = Accelerator()
  907. @accelerator.on_local_process(local_process_index=2)
  908. def print_something():
  909. print(f"Printed on process {accelerator.local_process_index}")
  910. print_something()
  911. # On server 1:
  912. "Printed on process 2"
  913. # On server 2:
  914. "Printed on process 2"
  915. ```
  916. """
  917. # Initial construction of the decorator.
  918. if (self is not None) and (local_process_index is not None) and (function is None):
  919. return partial(self.on_local_process, local_process_index=local_process_index)
  920. # For times when the `Accelerator` object itself utilizes this decorator.
  921. if function is None:
  922. if "Accelerator." in self.__qualname__:
  923. function = self
  924. else:
  925. raise ValueError(
  926. "The `on_main_process` decorator must be called with a function on an instantiated `Accelerator` object."
  927. )
  928. def _inner(*args, **kwargs):
  929. return PartialState().on_local_process(function, local_process_index)(*args, **kwargs)
  930. return _inner
  @contextmanager
  def main_process_first(self):
      """
      Context manager that lets the main process run the body of a `with` block first. The other processes
      enter the block once the main process has left it. Delegates to `self.state.main_process_first()`.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> with accelerator.main_process_first():
      ...     # This will be printed first by process 0 then in a seemingly
      ...     # random order by the other processes.
      ...     print(f"This will be printed by process {accelerator.process_index}")
      ```
      """
      state = self.state
      ordering = state.main_process_first()
      with ordering:
          yield
  @contextmanager
  def local_main_process_first(self):
      """
      Context manager that lets the local main process run the body of a `with` block first. The other
      processes enter the block once the local main process has left it. Delegates to
      `self.state.local_main_process_first()`.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> with accelerator.local_main_process_first():
      ...     # This will be printed first by local process 0 then in a seemingly
      ...     # random order by the other processes.
      ...     print(f"This will be printed by process {accelerator.local_process_index}")
      ```
      """
      state = self.state
      ordering = state.local_main_process_first()
      with ordering:
          yield
  @contextmanager
  def no_sync(self, model):
      """
      A context manager to disable gradient synchronizations across DDP processes by calling
      `torch.nn.parallel.DistributedDataParallel.no_sync`.

      If `model` has no `no_sync` method (e.g. it is not wrapped in DDP), or training is not distributed, this
      context manager does nothing.

      - With FSDP2, it calls `model.set_requires_gradient_sync(False)` for the duration of the block and
        `model.set_requires_gradient_sync(True)` afterwards, even if the block raises.
      - Under DeepSpeed, the model's `no_sync` is only used when the ZeRO stage is below 2; otherwise this is a
        no-op.

      Args:
          model (`torch.nn.Module`):
              PyTorch Module that was prepared with `Accelerator.prepare`

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)
      >>> input_a = next(iter(dataloader))
      >>> input_b = next(iter(dataloader))

      >>> with accelerator.no_sync(model):
      ...     outputs = model(input_a)
      ...     loss = loss_func(outputs)
      ...     accelerator.backward(loss)
      ...     # No synchronization across processes, only accumulate gradients
      >>> outputs = model(input_b)
      >>> loss = loss_func(outputs)
      >>> accelerator.backward(loss)
      >>> # Synchronization across all processes
      >>> optimizer.step()
      >>> optimizer.zero_grad()
      ```
      """
      if self.is_fsdp2:
          model.set_requires_gradient_sync(False)
          try:
              yield
          finally:
              model.set_requires_gradient_sync(True)
      else:
          context = contextlib.nullcontext
          if self.use_distributed:
              # ZeRO stage >= 2 is excluded here, so the model's `no_sync` is not used in that case.
              if self.distributed_type != DistributedType.DEEPSPEED or self.state.deepspeed_plugin.zero_stage < 2:
                  context = getattr(model, "no_sync", context)

          with context():
              yield
  @staticmethod
  @contextmanager
  def trigger_sync_in_backward(model):
      """Trigger the sync of the gradients in the next backward pass of the model after multiple forward passes under
      `Accelerator.no_sync` (only applicable in multi-GPU scenarios).

      If `model` is not a `torch.nn.parallel.DistributedDataParallel` instance, this context manager does nothing.
      The model's `require_backward_grad_sync` / `require_forward_param_sync` flags are restored on exit, even if
      preparing the reducer or the body of the `with` block raises.

      Args:
          model (`torch.nn.Module`):
              The model for which to trigger the gradient synchronization.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> dataloader, model, optimizer = accelerator.prepare(dataloader, model, optimizer)

      >>> with accelerator.no_sync(model):
      ...     loss_a = loss_func(model(input_a))  # first forward pass
      ...     loss_b = loss_func(model(input_b))  # second forward pass
      >>> accelerator.backward(loss_a)  # No synchronization across processes, only accumulate gradients
      >>> with accelerator.trigger_sync_in_backward(model):
      ...     accelerator.backward(loss_b)  # Synchronization across all processes
      >>> optimizer.step()
      >>> optimizer.zero_grad()
      ```
      """
      if not isinstance(model, torch.nn.parallel.DistributedDataParallel):
          yield
          return

      old_require_backward_grad_sync = model.require_backward_grad_sync
      old_require_forward_param_sync = model.require_forward_param_sync
      try:
          # EXPERIMENTAL: This will force grad sync during `backward()`, but it is unknown if it breaks other DDP features.
          # https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/torch/nn/parallel/distributed.py#L1453-L1466
          model.require_backward_grad_sync = True
          model.require_forward_param_sync = True
          # https://github.com/pytorch/pytorch/blob/e1502c0cdbfd17548c612f25d5a65b1e4b86224d/torch/csrc/distributed/c10d/reducer.cpp#L1371-L1402
          model.reducer.prepare_for_backward([])
          yield
      finally:
          # Mutations happen inside `try` so a failing `prepare_for_backward` cannot leave the flags forced on.
          model.require_backward_grad_sync = old_require_backward_grad_sync
          model.require_forward_param_sync = old_require_forward_param_sync
  1046. def _do_sync(self):
  1047. "Sets the right `sync_gradients` context and either resets or increases `self.step`"
  1048. if self.gradient_state.sync_with_dataloader and self.gradient_state.end_of_dataloader:
  1049. self.step = 0
  1050. self.gradient_state._set_sync_gradients(True)
  1051. else:
  1052. self.step += 1
  1053. self.gradient_state._set_sync_gradients((self.step % self.gradient_state.num_steps) == 0)
  1054. @property
  1055. def sync_gradients(self):
  1056. return self.gradient_state.sync_gradients
  1057. @sync_gradients.setter
  1058. def sync_gradients(self, sync_gradients):
  1059. self.gradient_state.sync_gradients = sync_gradients
  1060. @property
  1061. def gradient_accumulation_steps(self):
  1062. return self.gradient_state.num_steps
  1063. @gradient_accumulation_steps.setter
  1064. def gradient_accumulation_steps(self, gradient_accumulation_steps):
  1065. self.gradient_state.plugin_kwargs.update({"num_steps": gradient_accumulation_steps})
  1066. @contextmanager
  1067. def accumulate(self, *models):
  1068. """
  1069. A context manager that will lightly wrap around and perform gradient accumulation automatically
  1070. Args:
  1071. *models (list of `torch.nn.Module`):
  1072. PyTorch Modules that were prepared with `Accelerator.prepare`. Models passed to `accumulate()` will
  1073. skip gradient syncing during backward pass in distributed training
  1074. Example:
  1075. ```python
  1076. >>> from accelerate import Accelerator
  1077. >>> accelerator = Accelerator(gradient_accumulation_steps=1)
  1078. >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
  1079. >>> for input, output in dataloader:
  1080. ... with accelerator.accumulate(model):
  1081. ... outputs = model(input)
  1082. ... loss = loss_func(outputs)
  1083. ... loss.backward()
  1084. ... optimizer.step()
  1085. ... scheduler.step()
  1086. ... optimizer.zero_grad()
  1087. ```
  1088. """
  1089. self._do_sync()
  1090. allow_gradient_sync = (
  1091. self.sync_gradients # must sync if sync gradients need to complete an optimizer step
  1092. or (
  1093. # the no_sync context stops the gradients from reducing during distributed training
  1094. # bringing speedup (potentially at some costs). Here, no_sync can be prevented
  1095. # by setting sync_each_batch = True.
  1096. self.use_distributed # only relevant in distributed settings
  1097. and self.gradient_state.plugin_kwargs.get("sync_each_batch", False)
  1098. )
  1099. )
  1100. with contextlib.ExitStack() as cm_stack:
  1101. for m in models:
  1102. cm_stack.enter_context(contextlib.nullcontext() if allow_gradient_sync else self.no_sync(m))
  1103. yield
  1104. @contextmanager
  1105. def join_uneven_inputs(self, joinables, even_batches=None):
  1106. """
  1107. A context manager that facilitates distributed training or evaluation on uneven inputs, which acts as a wrapper
  1108. around `torch.distributed.algorithms.join`. This is useful when the total batch size does not evenly divide the
  1109. length of the dataset.
  1110. Args:
  1111. joinables (`list[torch.distributed.algorithms.Joinable]`):
  1112. A list of models or optimizers that subclass `torch.distributed.algorithms.Joinable`. Most commonly, a
  1113. PyTorch Module that was prepared with `Accelerator.prepare` for DistributedDataParallel training.
  1114. even_batches (`bool`, *optional*)
  1115. If set, this will override the value of `even_batches` set in the `Accelerator`. If it is not provided,
  1116. the default `Accelerator` value wil be used.
  1117. <Tip warning={true}>
  1118. `join_uneven_inputs` is only supported for Distributed Data Parallel training on multiple GPUs. For any other
  1119. configuration, this method will have no effect.
  1120. </Tip>
  1121. <Tip warning={true}>
  1122. Overriding `even_batches` will not affect iterable-style data loaders.
  1123. </Tip>
  1124. Example:
  1125. ```python
  1126. >>> from accelerate import Accelerator
  1127. >>> accelerator = Accelerator(even_batches=True)
  1128. >>> ddp_model, optimizer, dataloader = accelerator.prepare(model, optimizer, dataloader)
  1129. >>> with accelerator.join_uneven_inputs([ddp_model], even_batches=False):
  1130. ... for input, output in dataloader:
  1131. ... outputs = model(input)
  1132. ... loss = loss_func(outputs)
  1133. ... loss.backward()
  1134. ... optimizer.step()
  1135. ... optimizer.zero_grad()
  1136. ```
  1137. """
  1138. if self.multi_device:
  1139. dl_even_batches_values = []
  1140. if even_batches is not None:
  1141. iterable_dl_seen = False
  1142. # override value in batch sampler for map-style datasets
  1143. for dl_idx, dl in enumerate(self._dataloaders):
  1144. if isinstance(dl, DataLoaderDispatcher):
  1145. iterable_dl_seen = True
  1146. continue
  1147. dl_even_batches_values.append((dl_idx, dl.batch_sampler.even_batches))
  1148. dl.batch_sampler.even_batches = even_batches
  1149. if iterable_dl_seen:
  1150. warnings.warn(
  1151. "Overriding even_batches is only supported for map-style datasets, yet some dataloaders given were iterable"
  1152. )
  1153. else:
  1154. even_batches = self.even_batches
  1155. enable_join = False if even_batches else True
  1156. try:
  1157. with Join(joinables, enable=enable_join, throw_on_early_termination=False):
  1158. yield
  1159. finally:
  1160. # reset any batch samplers that have been modified
  1161. for dl_idx, even_batches_value in dl_even_batches_values:
  1162. self._dataloaders[dl_idx].batch_sampler.even_batches = even_batches_value
  1163. else:
  1164. # Even when disabled, Join expects models to subclass Joinable, so skip entirely for single process runs
  1165. if self.distributed_type != DistributedType.NO:
  1166. warnings.warn(
  1167. "Joining uneven inputs is only supported for multi-GPU training, as a result `join_uneven_inputs` will have no effect."
  1168. )
  1169. with contextlib.nullcontext(joinables):
  1170. yield
  1171. def print(self, *args, **kwargs):
  1172. """
  1173. Drop in replacement of `print()` to only print once per server.
  1174. Example:
  1175. ```python
  1176. >>> from accelerate import Accelerator
  1177. >>> accelerator = Accelerator()
  1178. >>> accelerator.print("Hello world!")
  1179. ```
  1180. """
  1181. self.state.print(*args, **kwargs)
  1182. def _prepare_one(self, obj, first_pass=False, device_placement=None):
  1183. # First pass of preparation: DataLoader, model, optimizer
  1184. if first_pass:
  1185. if isinstance(obj, torch.utils.data.DataLoader):
  1186. return self.prepare_data_loader(obj, device_placement=device_placement)
  1187. elif isinstance(obj, torch.nn.Module):
  1188. return self.prepare_model(obj, device_placement=device_placement)
  1189. elif isinstance(obj, torch.optim.Optimizer):
  1190. optimizer = self.prepare_optimizer(obj, device_placement=device_placement)
  1191. return optimizer
  1192. # Second pass of preparation: LR scheduler (which need the full list of optimizers)
  1193. elif isinstance(obj, LRScheduler):
  1194. scheduler = self.prepare_scheduler(obj)
  1195. return scheduler
  1196. # Return the unprocessed object if previous criteria was not met
  1197. return obj
  1198. def prepare(self, *args, device_placement=None):
  1199. """
  1200. Prepare all objects passed in `args` for distributed training and mixed precision, then return them in the same
  1201. order.
  1202. Args:
  1203. *args (list of objects):
  1204. Any of the following type of objects:
  1205. - `torch.utils.data.DataLoader`: PyTorch Dataloader
  1206. - `torch.nn.Module`: PyTorch Module
  1207. - `torch.optim.Optimizer`: PyTorch Optimizer
  1208. - `torch.optim.lr_scheduler.LRScheduler`: PyTorch LR Scheduler
  1209. device_placement (`list[bool]`, *optional*):
  1210. Used to customize whether automatic device placement should be performed for each object passed. Needs
  1211. to be a list of the same length as `args`. Not compatible with DeepSpeed or FSDP.
  1212. <Tip>
  1213. You don't need to prepare a model if you only use it for inference without any kind of mixed precision
  1214. </Tip>
  1215. Examples:
  1216. ```python
  1217. >>> from accelerate import Accelerator
  1218. >>> accelerator = Accelerator()
  1219. >>> # Assume a model, optimizer, data_loader and scheduler are defined
  1220. >>> model, optimizer, data_loader, scheduler = accelerator.prepare(model, optimizer, data_loader, scheduler)
  1221. ```
  1222. ```python
  1223. >>> from accelerate import Accelerator
  1224. >>> accelerator = Accelerator()
  1225. >>> # Assume a model, optimizer, data_loader and scheduler are defined
  1226. >>> device_placement = [True, True, False, False]
  1227. >>> # Will place the first two items passed in automatically to the right device but not the last two.
  1228. >>> model, optimizer, data_loader, scheduler = accelerator.prepare(
  1229. ... model, optimizer, data_loader, scheduler, device_placement=device_placement
  1230. ... )
  1231. ```
  1232. """
  1233. if device_placement is None:
  1234. device_placement = [None for _ in args]
  1235. elif self.distributed_type in (DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM):
  1236. raise ValueError("You can't customize device placements with DeepSpeed or Megatron-LM.")
  1237. elif len(device_placement) != len(args):
  1238. raise ValueError(
  1239. f"`device_placement` should be a list with {len(args)} elements (the number of objects passed)."
  1240. )
  1241. for obj in args:
  1242. # TODO: Look at enabling native TP training directly with a proper config
  1243. if (
  1244. isinstance(obj, torch.nn.Module)
  1245. and self.verify_device_map(obj)
  1246. and self.distributed_type != DistributedType.NO
  1247. and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
  1248. ):
  1249. raise ValueError(
  1250. "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
  1251. " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
  1252. )
  1253. if self.distributed_type == DistributedType.DEEPSPEED:
  1254. model_count = 0
  1255. for obj in args:
  1256. if isinstance(obj, torch.nn.Module):
  1257. model_count += 1
  1258. if model_count > 1:
  1259. raise AssertionError(
  1260. "You can't use same `Accelerator()` instance with multiple models when using DeepSpeed"
  1261. )
  1262. # On TPUs, putting the model on the XLA device will create new parameters, so the corresponding optimizer will
  1263. # have parameters disconnected from the model (so no training :-( ).
  1264. # If the model and optimizer have parameters on different devices we raise an error.
  1265. if self.distributed_type == DistributedType.XLA:
  1266. model_device, optimizer_device = self._get_devices()
  1267. if model_device is not None and optimizer_device is not None and model_device != optimizer_device:
  1268. raise ValueError(
  1269. "The model and the optimizer parameters are not on the same device, which probably means you "
  1270. "created an optimizer around your model **before** putting on the device. Make sure the line "
  1271. "model.to(device) is before the optimizer creation in your script or remove it entirely and use "
  1272. "the flag default value for `device_placement` in your `Accelerator` to let it handle that "
  1273. "part for you."
  1274. )
  1275. if self.is_fsdp2:
  1276. model_count = 0
  1277. optimizer_count = 0
  1278. for i, obj in enumerate(args):
  1279. if isinstance(obj, torch.nn.Module):
  1280. model_count += 1
  1281. elif isinstance(obj, torch.optim.Optimizer):
  1282. optimizer_count += 1
  1283. # This needs to be written as such, so that passing other objects other than models/optimizers doesn't raise an error
  1284. if (model_count < 1 and optimizer_count > 0) or (model_count > 0 and optimizer_count < 1):
  1285. raise ValueError(
  1286. "When using FSDP2, a model and optimizer must be passed together to `Accelerator.prepare()`"
  1287. " as the optimizer needs to have its parameters modified after the model is converted."
  1288. )
  1289. if model_count > 1:
  1290. raise ValueError("Only one model is supported when using FSDP2")
  1291. # If we're dealing with device placement, this deals with that by...
  1292. tpu_should_fix_optimizer = self.device_placement and self.distributed_type == DistributedType.XLA
  1293. if tpu_should_fix_optimizer:
  1294. # 1. grabbing old model parameters
  1295. old_named_params = self._get_named_parameters(*args, drop_refs=False)
  1296. if self.distributed_type in [DistributedType.MULTI_CPU, DistributedType.MULTI_XPU, DistributedType.NO]:
  1297. if (
  1298. is_torch_version("<", "2.7.0")
  1299. and (self.device.type == "cpu" or self.device.type == "xpu")
  1300. and self.state.use_ipex
  1301. ):
  1302. logger.warning(
  1303. "You are using lower version of PyTorch(< 2.7.0) with ipex acceleration on Intel CPU or XPU, Intel has upstreamed most of the optimizations into stock PyTorch from 2.7.0, we encourage you to install the latest stock PyTorch and enjoy the out-of-experience on Intel CPU/XPU."
  1304. )
  1305. args = self._prepare_ipex(*args)
  1306. if self.parallelism_config and self.parallelism_config.tp_enabled:
  1307. args = self._prepare_tp(*args)
  1308. if self.parallelism_config and self.parallelism_config.cp_enabled:
  1309. args = self._prepare_cp(*args)
  1310. if self.fp8_backend == FP8BackendType.TE:
  1311. args = self._prepare_te(*args)
  1312. elif self.fp8_backend == FP8BackendType.AO:
  1313. args = self._prepare_ao(*args)
  1314. if self.distributed_type == DistributedType.DEEPSPEED:
  1315. result = self._prepare_deepspeed(*args)
  1316. elif self.distributed_type == DistributedType.MEGATRON_LM:
  1317. result = self._prepare_megatron_lm(*args)
  1318. elif self.is_fsdp2:
  1319. result = self._prepare_fsdp2(*args)
  1320. else:
  1321. if self.fp8_backend == FP8BackendType.MSAMP:
  1322. args, device_placement = self._prepare_msamp(*args, device_placement=device_placement)
  1323. result = tuple(
  1324. self._prepare_one(obj, first_pass=True, device_placement=d) for obj, d in zip(args, device_placement)
  1325. )
  1326. result = tuple(self._prepare_one(obj, device_placement=d) for obj, d in zip(result, device_placement))
  1327. if tpu_should_fix_optimizer:
  1328. # 2. grabbing new model parameters
  1329. new_named_params = self._get_named_parameters(*result)
  1330. # 3. building a map from the first to the second
  1331. mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
  1332. # 4. using that map to update the parameters of the optimizer
  1333. for obj in result:
  1334. if isinstance(obj, torch.optim.Optimizer):
  1335. obj._switch_parameters(mapping)
  1336. for item in result:
  1337. if any(
  1338. item in container
  1339. for container in (self._dataloaders, self._models, self._optimizers, self._schedulers)
  1340. ):
  1341. item._is_accelerate_prepared = True
  1342. return result if len(result) > 1 else result[0]
  1343. def _prepare_tp(self, *args):
  1344. # First pass: prepare everything except schedulers (and model, which is prepared separately below)
  1345. result = [
  1346. self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args
  1347. ]
  1348. # Second pass: prepare schedulers
  1349. result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]
  1350. device_mesh = self.torch_device_mesh
  1351. old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
  1352. for arg in result:
  1353. if not isinstance(arg, torch.nn.Module):
  1354. continue
  1355. from torch.distributed.tensor import DTensor, Replicate
  1356. from transformers.integrations.tensor_parallel import ReplicateParallel
  1357. model: torch.nn.Module = arg
  1358. tp_plan = ReplicateParallel
  1359. for name, param in model.named_parameters():
  1360. if isinstance(param, DTensor):
  1361. continue
  1362. dp = DTensor.from_local(param, device_mesh=device_mesh["tp"], placements=[Replicate()])
  1363. param_name, param_type = name.rsplit(".", 1)
  1364. module_to_tp = model.get_submodule(param_name)
  1365. tp_plan().prepare_module_tp(module_to_tp, device_mesh["tp"])
  1366. if not isinstance(dp, torch.nn.Parameter):
  1367. dp = torch.nn.Parameter(dp, requires_grad=param.requires_grad)
  1368. setattr(module_to_tp, param_type, dp)
  1369. new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=False))
  1370. # Build a map from old to new params
  1371. mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
  1372. def _get_tensor_address(p):
  1373. if isinstance(p, DTensor):
  1374. return p._local_tensor.data_ptr()
  1375. return p.data_ptr()
  1376. for obj in result:
  1377. if isinstance(obj, torch.optim.Optimizer):
  1378. for param_group in obj.param_groups:
  1379. # Each param_group originally maps to model parameters (e.g., from model.parameters()).
  1380. # After _prepare_tp(), parameter references are replaced with DTensor instances.
  1381. # Therefore, we remap the parameter references to their new DTensor addresses
  1382. # so that the optimizer can correctly update the model parameters.
  1383. param_group["params"] = [mapping[_get_tensor_address(p)] for p in param_group["params"]]
  1384. return args
  1385. def _prepare_cp(self, *args):
  1386. if self.parallelism_config.sp_backend == "deepspeed":
  1387. # deepspeed handles cp in a different way, configured in _prepare_deepspeed
  1388. return args
  1389. from torch.distributed.tensor.experimental import context_parallel
  1390. from torch.distributed.tensor.experimental._attention import set_rotate_method
  1391. cp_comm_strategy = self.parallelism_config.cp_handler.cp_comm_strategy
  1392. set_rotate_method(cp_comm_strategy)
  1393. self._cp_context = functools.partial(context_parallel, mesh=self.torch_device_mesh["cp"])
  1394. for arg in args:
  1395. if isinstance(arg, torch.nn.Module):
  1396. _attach_context_parallel_hooks(arg)
  1397. return args
  1398. def _prepare_fsdp2(self, *args):
  1399. # First pass: prepare everything except schedulers (and model, which is prepared separately below)
  1400. result = [
  1401. self._prepare_one(obj, first_pass=True) if not isinstance(obj, torch.nn.Module) else obj for obj in args
  1402. ]
  1403. # Second pass: prepare schedulers
  1404. result = [self._prepare_one(obj) if not isinstance(obj, torch.nn.Module) else obj for obj in result]
  1405. # Prepare the model
  1406. model_index, model = None, None
  1407. for i, obj in enumerate(result):
  1408. if isinstance(obj, torch.nn.Module):
  1409. model_index, model = i, obj
  1410. # Invariant: if we have a model, we also have an optimizer (checked in `prepare`)
  1411. if model_index is None:
  1412. return tuple(result)
  1413. # Needs to be done first, to make sure AC + fully_shard will work as expected
  1414. self.state.fsdp_plugin.set_auto_wrap_policy(model)
  1415. # Apply AC if needed
  1416. if self.state.fsdp_plugin.activation_checkpointing:
  1417. model = fsdp2_apply_ac(self, model)
  1418. # Apply compile if needed, has to be *after* applying AC
  1419. # Copied from: `accelerator.prepare_model` ~ L1804
  1420. if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
  1421. if self.state.dynamo_plugin.use_regional_compilation:
  1422. model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
  1423. else:
  1424. model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
  1425. # Get old params and canonicalize - we canonicalize to have the mapping easy
  1426. old_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*tuple(result), drop_refs=True))
  1427. # Swap the optimizer parameters with empty, so `fully_shard` after will not allocate too much memory
  1428. from torch.distributed.tensor import DTensor
  1429. for obj in result:
  1430. if isinstance(obj, torch.optim.Optimizer):
  1431. for param_group in obj.param_groups:
  1432. for i, p in enumerate(param_group["params"]):
  1433. # We drop a reference to the original param here, so that _move_states_to_device triggers a reallocation
  1434. # We reassign the data_ptr to the original param, so that we preserve the mapping to the new ones
  1435. param_group["params"][i] = torch.empty(1, dtype=p.dtype, device=p.device)
  1436. param_group["params"][i].data_ptr = (
  1437. p._local_tensor.data_ptr() if isinstance(p, DTensor) else p.data_ptr()
  1438. )
  1439. self._models.append(model)
  1440. # Prepare everything FSDP2 related for the model (except AC)
  1441. model = fsdp2_prepare_model(self, model)
  1442. # Remove the old model from the list
  1443. if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
  1444. del self._models[-2]
  1445. # Replace the old model with the new one (shouldn't be needed as everything should be in place)
  1446. result[model_index] = model
  1447. # Get new params and canonicalize
  1448. new_named_params = fsdp2_canonicalize_names(self._get_named_parameters(*result))
  1449. # Build a map from old to new params
  1450. mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
  1451. # Update the optimizer parameters
  1452. for obj in result:
  1453. if isinstance(obj, torch.optim.Optimizer):
  1454. fsdp2_switch_optimizer_parameters(obj, mapping)
  1455. return result
  1456. def prepare_model(
  1457. self, model: torch.nn.Module, device_placement: bool | None = None, evaluation_mode: bool = False
  1458. ):
  1459. """
  1460. Prepares a PyTorch model for training in any distributed setup. It is recommended to use
  1461. [`Accelerator.prepare`] instead.
  1462. Args:
  1463. model (`torch.nn.Module`):
  1464. A PyTorch model to prepare. You don't need to prepare a model if it is used only for inference without
  1465. any kind of mixed precision
  1466. device_placement (`bool`, *optional*):
  1467. Whether or not to place the model on the proper device. Will default to `self.device_placement`.
  1468. evaluation_mode (`bool`, *optional*, defaults to `False`):
  1469. Whether or not to set the model for evaluation only, by just applying mixed precision and
  1470. `torch.compile` (if configured in the `Accelerator` object).
  1471. Example:
  1472. ```python
  1473. >>> from accelerate import Accelerator
  1474. >>> accelerator = Accelerator()
  1475. >>> # Assume a model is defined
  1476. >>> model = accelerator.prepare_model(model)
  1477. ```
  1478. """
  1479. if device_placement is None:
  1480. device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
  1481. self._models.append(model)
  1482. # TODO: Look at enabling native TP training directly with a proper config
  1483. if (
  1484. self.verify_device_map(model)
  1485. and self.distributed_type != DistributedType.NO
  1486. and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
  1487. ):
  1488. raise ValueError(
  1489. "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
  1490. " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
  1491. )
  1492. if self.native_amp:
  1493. model._original_forward = model.forward
  1494. autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
  1495. # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
  1496. if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"):
  1497. model_forward_func = model.forward
  1498. model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
  1499. else:
  1500. model_forward_func = model.forward.__func__
  1501. new_forward = autocast_context(model_forward_func)
  1502. model.forward = MethodType(new_forward, model)
  1503. model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
  1504. # We prepare TE after, allowing for bf16 autocast to happen first
  1505. if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:
  1506. model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
  1507. if (getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)) and getattr(
  1508. model, "hf_device_map", False
  1509. ):
  1510. model_devices = set(model.hf_device_map.values())
  1511. if len(model_devices) > 1 and self.distributed_type != DistributedType.NO:
  1512. raise ValueError(
  1513. "You can't train a model that has been loaded in 8-bit or 4-bit precision on multiple devices in any distributed mode."
  1514. " In order to use 8-bit or 4-bit models that have been loaded across multiple GPUs the solution is to use Naive Pipeline Parallelism."
  1515. " Therefore you should not specify that you are under any distributed regime in your accelerate config."
  1516. )
  1517. elif len(model_devices) == 1:
  1518. current_device = list(model_devices)[0]
  1519. if isinstance(current_device, torch.device):
  1520. current_device_index = current_device.index
  1521. elif isinstance(current_device, str):
  1522. current_device_index = torch.device(current_device).index
  1523. else:
  1524. current_device_index = current_device
  1525. if self.device.type == "cpu" and is_bitsandbytes_multi_backend_available():
  1526. # bnb with multi-backend supports CPU which don't need to check index.
  1527. pass
  1528. elif torch.device(current_device_index) != self.device:
  1529. # if on the first device (GPU 0) we don't care
  1530. if (self.device.index is not None) or (current_device_index != 0):
  1531. raise ValueError(
  1532. "You can't train a model that has been loaded in 8-bit or 4-bit precision on a different device than the one "
  1533. "you're training on. Make sure you loaded the model on the correct device using for example `device_map={'':torch.cuda.current_device()}` or `device_map={'':torch.xpu.current_device()}`"
  1534. )
  1535. if (
  1536. ("cpu" in model_devices and not is_bitsandbytes_multi_backend_available())
  1537. or ("cpu" in model_devices and is_xpu_available())
  1538. or "disk" in model_devices
  1539. ):
  1540. raise ValueError(
  1541. "You can't train a model that has been loaded in 8-bit or 4-bit precision with CPU or disk offload. "
  1542. "If you want train the 8-bit or 4-bit model in CPU, please install bitsandbytes with multi-backend, see https://huggingface.co/docs/bitsandbytes/main/en/installation#multi-backend"
  1543. )
  1544. elif device_placement and not self.verify_device_map(model):
  1545. model = model.to(self.device)
  1546. if not evaluation_mode:
  1547. if self.multi_device and not (self.parallelism_config and self.parallelism_config.tp_enabled):
  1548. if model_has_dtensor(model):
  1549. raise ValueError(
  1550. "Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead."
  1551. )
  1552. if any(p.requires_grad for p in model.parameters()):
  1553. kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
  1554. # TODO: Look at enabling native TP training directly with a proper config
  1555. if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true":
  1556. if self.device.type == "hpu":
  1557. device_ids, output_device = [self.device.index], self.device.index
  1558. else:
  1559. device_ids, output_device = [self.local_process_index], self.local_process_index
  1560. else:
  1561. device_ids, output_device = None, None
  1562. model = torch.nn.parallel.DistributedDataParallel(
  1563. model, device_ids=device_ids, output_device=output_device, **kwargs
  1564. )
  1565. if self.ddp_handler is not None:
  1566. self.ddp_handler.register_comm_hook(model)
  1567. elif self.parallelism_config and self.parallelism_config.tp_enabled:
  1568. if not hasattr(model, "tp_size"):
  1569. raise NotImplementedError(
  1570. "Model should undergo tensor parallel before passing it to accelerate."
  1571. "You can use .from_pretrained(..., tp_plan='auto') if the model supports"
  1572. )
  1573. if model.tp_size != self.parallelism_config.tp_size:
  1574. raise ValueError(
  1575. f"tp_size in the plugin {self.parallelism_config.tp_size} should be same as model's tp size {model.tp_size}"
  1576. )
  1577. elif self.is_fsdp2:
  1578. raise ValueError(
  1579. "FSDP2 preparation should be done via `accelerate.prepare()`, as it requires a model and an optimizer."
  1580. )
  1581. elif self.distributed_type == DistributedType.FSDP:
  1582. # We need to fix the optimizer *before* sharding the model
  1583. from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
  1584. # Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
  1585. # don't wrap it again
  1586. # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
  1587. # is a FSDP model, don't wrap it again
  1588. is_type_fsdp = isinstance(model, FSDP) or (
  1589. is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
  1590. )
  1591. if not is_type_fsdp:
  1592. self.state.fsdp_plugin.set_auto_wrap_policy(model)
  1593. fsdp_plugin = self.state.fsdp_plugin
  1594. # need to ensure that params are re-tied after running
  1595. # param_init_fn
  1596. fsdp_plugin.param_init_fn = ensure_weights_retied(
  1597. fsdp_plugin.param_init_fn,
  1598. model,
  1599. self.device,
  1600. )
  1601. kwargs = {
  1602. # We fallback to reshard_after_forward if sharding_strategy is not set.
  1603. # We prerfer sharding_strategy to not break the behavior of the existing code.
  1604. # Deprecation warning has already been issued in `utils.dataclasses.py`
  1605. "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
  1606. "cpu_offload": fsdp_plugin.cpu_offload,
  1607. "auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
  1608. "mixed_precision": fsdp_plugin.mixed_precision_policy,
  1609. "sync_module_states": fsdp_plugin.sync_module_states,
  1610. "backward_prefetch": fsdp_plugin.backward_prefetch,
  1611. "forward_prefetch": fsdp_plugin.forward_prefetch,
  1612. "use_orig_params": fsdp_plugin.use_orig_params,
  1613. "param_init_fn": fsdp_plugin.param_init_fn,
  1614. "ignored_modules": fsdp_plugin.ignored_modules,
  1615. "limit_all_gathers": fsdp_plugin.limit_all_gathers,
  1616. "device_id": self.device,
  1617. }
  1618. if isinstance(kwargs["ignored_modules"], str):
  1619. reg = re.compile(kwargs["ignored_modules"])
  1620. ignored = []
  1621. for name, module in model.named_modules():
  1622. if reg.fullmatch(name):
  1623. # ensure that the device for these modules is still set correctly
  1624. module.to(self.device)
  1625. ignored.append(module)
  1626. kwargs["ignored_modules"] = ignored
  1627. model = FSDP(model, **kwargs)
  1628. if fsdp_plugin.activation_checkpointing:
  1629. from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
  1630. CheckpointImpl,
  1631. apply_activation_checkpointing,
  1632. checkpoint_wrapper,
  1633. )
  1634. apply_activation_checkpointing(
  1635. model,
  1636. checkpoint_wrapper_fn=functools.partial(
  1637. checkpoint_wrapper,
  1638. checkpoint_impl=CheckpointImpl.NO_REENTRANT,
  1639. ),
  1640. auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
  1641. )
  1642. # In the event the model had been loaded in low precision, but
  1643. # mixed precision had also been activated, then we follow DeepSpeed's
  1644. # strategy to hold the parameters in full precision.
  1645. # - assume that trainer.args.bf16 and trainer.args.fp16 are already checked against
  1646. # fsdp_plugin.mixed_precision_policy.
  1647. # - NOTE: we do not check the mixed_precision attribute on the FSDP root wrapper.
  1648. # * this attribute will always set by init_utils.init_core_state so its always not None.
  1649. # * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype
  1650. # * if model is loaded in 16bit, and even if mixed_precision.param_dtype is None,
  1651. # we still want to upcast the flat_param.
  1652. if self.mixed_precision != "no": # if mixed precision is set
  1653. upcasted_log = []
  1654. for module in FSDP.fsdp_modules(model):
  1655. # Referencing DeepSpeed Zero3
  1656. # - in Init, params are converted to 16bit while partitioning.
  1657. # - in accelerator.prepare, deepspeed.initialize is called to:
  1658. # * creates the DeepSpeedEngine.
  1659. # * since zero_optimization() is True , calls engine._configure_zero_optimizer.
  1660. #
  1661. # Inside the DeepSpeed Zero3 optimizer configuration, which initializes
  1662. # DeepSpeedZeroOptimizer_Stage3, during which:
  1663. # * trainable_param_groups are obtained from the attached optimizer
  1664. # (already partitioned in 16bit).
  1665. # * then _setup_for_real_optimizer -> _create_fp32_partitions
  1666. # which performs the fp32 upcasting.
  1667. # To mimic DeepSeepds's casting in FSDP, we look at the (single) FlatParameter held
  1668. # within an FSDP wrapper. This FlatParameter will be seen by the optimizer.
  1669. # - even though there is a torch.device('meta') guard below, we
  1670. # expect _init_utils._init_param_handle_from_module to already
  1671. # sync the parameter.
  1672. if not module._has_params:
  1673. continue # skip if FSDP module not managing parameters
  1674. param = module._flat_param
  1675. if (
  1676. param.dtype != torch.float32
  1677. and param.device != torch.device("meta")
  1678. and param.requires_grad
  1679. ):
  1680. # keep log of names_params that was upcasted
  1681. # NOTE: resorted to this because warnings.simplefilter("once") is somehow not working
  1682. name_param_log = (module.module.__class__.__name__, ", ".join(module._flat_param._fqns))
  1683. if name_param_log not in upcasted_log:
  1684. upcasted_log.append(name_param_log)
  1685. # this works because of FSDP's _runtime_utils.lazy_init.
  1686. # Have to be careful not to call anything before this that
  1687. # triggers lazy_init (e.g., _is_fsdp_root).
  1688. param.data = param.data.to(torch.float32) # upcasting
  1689. module._handle._orig_param_dtype = torch.float32 # update
  1690. # report the warnings
  1691. # some messages can be quite repetitive, especially when reporting about layers that have identical architecture.
  1692. if self.is_main_process:
  1693. for name_log, param_log in upcasted_log:
  1694. warnings.warn(
  1695. f"Upcasted low precision parameters in {name_log} because mixed precision turned on in FSDP. "
  1696. f"Affects: {param_log}."
  1697. )
  1698. if len(upcasted_log) > 0:
  1699. warnings.warn(
  1700. "FSDP upcast of low precision parameters may affect the precision of model checkpoints."
  1701. )
  1702. # if the previous and current models are same, delete the previous one
  1703. if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
  1704. del self._models[-2]
  1705. self._models[-1] = model
  1706. elif self.distributed_type == DistributedType.MULTI_CPU:
  1707. kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler else {}
  1708. model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
  1709. if self.ddp_handler is not None:
  1710. self.ddp_handler.register_comm_hook(model)
  1711. elif self.distributed_type == DistributedType.XLA and self.state.fork_launched:
  1712. model = xmp.MpModelWrapper(model).to(self.device)
  1713. # Now we can apply the FP8 autocast
  1714. if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
  1715. model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
  1716. # torch.compile should be called last and only if the model isn't already compiled
  1717. if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
  1718. if self.state.dynamo_plugin.use_regional_compilation:
  1719. model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
  1720. else:
  1721. model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
  1722. return model
  1723. def _prepare_ao(self, *args):
  1724. if not is_torchao_available():
  1725. raise ImportError(
  1726. "`torchao` was not found on your system or is too old of a version. Please ensure that `torchao >= 0.6.1` is installed"
  1727. )
  1728. if self.is_fsdp2:
  1729. models = [x for x in args if isinstance(x, torch.nn.Module)]
  1730. optimizers = [x for x in args if isinstance(x, torch.optim.Optimizer)]
  1731. for arg in args:
  1732. if isinstance(arg, torch.nn.Module):
  1733. convert_model_to_fp8_ao(
  1734. arg,
  1735. config=self.ao_recipe_handler.config,
  1736. module_filter_func=self.ao_recipe_handler.module_filter_func,
  1737. )
  1738. # Invariant: with FSDP2, optimizer is always passed to `prepare()` together with model
  1739. # We only precompute scales if float8 all gather is enabled, possibly can add a flag for this later
  1740. if self.is_fsdp2 and len(optimizers) > 0 and self.ao_recipe_handler.config.enable_fsdp_float8_all_gather:
  1741. from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
  1742. optimizers[0].register_step_post_hook(
  1743. lambda *args, **kwargs: precompute_float8_dynamic_scale_for_fsdp(models[0])
  1744. )
  1745. return args
  1746. def _prepare_te(self, *args):
  1747. if not is_transformer_engine_available():
  1748. raise ImportError(
  1749. "`transformer_engine` was not found on your system. Please ensure that `transformer_engine` is installed"
  1750. )
  1751. model, optimizer = None, None
  1752. num_models, num_optimizers = 0, 0
  1753. result = [obj for obj in args]
  1754. for obj in result:
  1755. if isinstance(obj, torch.nn.Module):
  1756. model = obj
  1757. num_models += 1
  1758. elif isinstance(obj, (torch.optim.Optimizer)):
  1759. optimizer = obj
  1760. num_optimizers += 1
  1761. if optimizer is None and model is None:
  1762. return result
  1763. elif optimizer is None or model is None:
  1764. raise ValueError(
  1765. "You must pass a model and an optimizer together to `accelerate.prepare()` when using TransformerEngine."
  1766. )
  1767. elif num_models > 1 or num_optimizers > 1:
  1768. raise ValueError(
  1769. f"You can't use multiple models ({num_models}) or optimizers {num_optimizers} with TransformerEngine."
  1770. )
  1771. old_named_params = self._get_named_parameters(model)
  1772. with torch.no_grad():
  1773. convert_model(model)
  1774. new_named_params = self._get_named_parameters(model)
  1775. mapping = {p: new_named_params[n] for n, p in old_named_params.items()}
  1776. # We need to switch the optimizer params to the new params *after* the model is wrapped in FSDP
  1777. for param_group in optimizer.param_groups:
  1778. param_group["params"] = [mapping[p] for p in param_group["params"]]
  1779. return result
  1780. def _prepare_deepspeed(self, *args):
  1781. import deepspeed
  1782. ds_initialize = deepspeed.initialize
  1783. if self.fp8_backend == FP8BackendType.MSAMP:
  1784. # MS-AMP requires DeepSpeed patches
  1785. from msamp import deepspeed as msamp_deepspeed
  1786. ds_initialize = msamp_deepspeed.initialize
  1787. deepspeed_plugin = self.deepspeed_plugin
  1788. is_dataloader_present = any(isinstance(obj, torch.utils.data.DataLoader) for obj in args)
  1789. tp_size = deepspeed_plugin.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0)
  1790. sp_backend = self.parallelism_config.sp_backend if self.parallelism_config else None
  1791. sp_size = self.parallelism_config.sp_size if self.parallelism_config else 1
  1792. sp_handler = self.parallelism_config.sp_handler if self.parallelism_config else None
  1793. if tp_size > 1:
  1794. if not compare_versions("deepspeed", ">=", "0.16.4"):
  1795. raise ImportError(
  1796. "Deepspeed TP requires deepspeed >= 0.16.4, Please update DeepSpeed via `pip install deepspeed -U`."
  1797. )
  1798. if not is_torch_version(">=", "2.2.0"):
  1799. raise ImportError(
  1800. "Tried to use TP, but `torch.distributed.device_mesh` requires PyTorch >= 2.2.0. Please upgrade your PyTorch version"
  1801. )
  1802. from torch.distributed.device_mesh import init_device_mesh
  1803. mesh_dim_name = "tp"
  1804. self.state.ds_device_mesh = init_device_mesh(self.device.type, (tp_size,), mesh_dim_names=(mesh_dim_name,))
  1805. result = [
  1806. self._prepare_one(obj, first_pass=True) if isinstance(obj, torch.utils.data.DataLoader) else obj
  1807. for obj in args
  1808. ]
  1809. if deepspeed_plugin.is_auto("train_micro_batch_size_per_gpu"):
  1810. if is_dataloader_present:
  1811. batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
  1812. if any(bs is None for bs in batch_sizes):
  1813. raise ValueError(
  1814. "At least one of the dataloaders passed to `accelerate.prepare()` has `None` as batch size. "
  1815. "Please set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
  1816. "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
  1817. )
  1818. if self.split_batches:
  1819. batch_sizes = [batch_size // self.num_processes for batch_size in batch_sizes]
  1820. batch_size_per_device = min(batch_sizes) if deepspeed_plugin.is_train_batch_min else max(batch_sizes)
  1821. if len(batch_sizes) > 1:
  1822. logger.info(
  1823. "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
  1824. f"{deepspeed_plugin.is_train_batch_min} will decide the `train_batch_size` ({batch_size_per_device})."
  1825. )
  1826. else:
  1827. raise ValueError(
  1828. "When using DeepSpeed, `accelerate.prepare()` requires you to pass at least one of training or evaluation dataloaders "
  1829. "with `batch_size` attribute returning an integer value "
  1830. "or alternatively set an integer value in `train_micro_batch_size_per_gpu` in the deepspeed config file "
  1831. "or assign integer value to `AcceleratorState().deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu']`."
  1832. )
  1833. else:
  1834. batch_size_per_device = deepspeed_plugin.get_value("train_micro_batch_size_per_gpu")
  1835. # handle `gradient_accumulation_steps` when the value is `auto`
  1836. deepspeed_plugin.fill_match(
  1837. "gradient_accumulation_steps",
  1838. must_match=False,
  1839. gradient_accumulation_steps=self.gradient_accumulation_steps,
  1840. )
  1841. deepspeed_gradient_accumulation_steps = deepspeed_plugin.get_value("gradient_accumulation_steps")
  1842. # update gradient_accumulation_steps if there is a mismatch
  1843. if deepspeed_gradient_accumulation_steps != self.gradient_accumulation_steps:
  1844. logger.warning(
  1845. f"Gradient accumulation steps mismatch: GradientAccumulationPlugin has {self.gradient_accumulation_steps}, "
  1846. f"DeepSpeed config has {deepspeed_gradient_accumulation_steps}. Using DeepSpeed's value."
  1847. )
  1848. self.gradient_accumulation_steps = deepspeed_gradient_accumulation_steps
  1849. config_kwargs = {
  1850. "gradient_clipping": 1.0,
  1851. "zero_optimization.stage3_gather_16bit_weights_on_model_save": False,
  1852. }
  1853. # This block is skipped when preparing just a model and DL is absent from current call's args
  1854. if batch_size_per_device is not None:
  1855. config_kwargs["train_micro_batch_size_per_gpu"] = batch_size_per_device
  1856. config_kwargs["train_batch_size"] = (
  1857. batch_size_per_device
  1858. * deepspeed_plugin.get_value("gradient_accumulation_steps")
  1859. * self.num_processes
  1860. // sp_size
  1861. )
  1862. model = None
  1863. optimizer = None
  1864. scheduler = None
  1865. for obj in result:
  1866. if isinstance(obj, torch.nn.Module):
  1867. model = obj
  1868. elif isinstance(obj, (torch.optim.Optimizer, DummyOptim)):
  1869. optimizer = obj
  1870. elif (isinstance(obj, (LRScheduler, DummyScheduler))) or (
  1871. type(obj).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
  1872. ):
  1873. scheduler = obj
  1874. if optimizer is not None:
  1875. if "optimizer" in deepspeed_plugin.deepspeed_config and not isinstance(optimizer, (DummyOptim)):
  1876. raise ValueError(
  1877. "You cannot specify an optimizer in the config file and in the code at the same time. "
  1878. "Please remove the optimizer from the config file or "
  1879. "create `accelerate.utils.DummyOptim` in the code."
  1880. )
  1881. elif "optimizer" not in deepspeed_plugin.deepspeed_config and isinstance(optimizer, (DummyOptim)):
  1882. raise ValueError(
  1883. "You cannot create a `DummyOptim` without specifying an optimizer in the config file."
  1884. )
  1885. if isinstance(optimizer, (torch.optim.Optimizer)):
  1886. deepspeed_plugin.deepspeed_config["zero_allow_untested_optimizer"] = True
  1887. if scheduler is not None:
  1888. if "scheduler" in deepspeed_plugin.deepspeed_config and not isinstance(scheduler, (DummyScheduler)):
  1889. raise ValueError(
  1890. "You cannot specify a scheduler in the config file and in the code at the same time. "
  1891. "Please remove the scheduler from the config file or "
  1892. "create `accelerate.utils.DummyScheduler` in the code."
  1893. )
  1894. elif (
  1895. "scheduler" not in deepspeed_plugin.deepspeed_config
  1896. and isinstance(scheduler, (DummyScheduler))
  1897. and scheduler.lr_scheduler_callable is None
  1898. ):
  1899. raise ValueError(
  1900. "Either specify a scheduler in the config file or "
  1901. "pass in the `lr_scheduler_callable` parameter when using `accelerate.utils.DummyScheduler`."
  1902. )
  1903. if optimizer is not None and scheduler is not None:
  1904. if isinstance(optimizer, (DummyOptim)) and not isinstance(scheduler, (DummyScheduler)):
  1905. raise ValueError(
  1906. "You can only specify `accelerate.utils.DummyScheduler` in the code when using "
  1907. "`accelerate.utils.DummyOptim`."
  1908. )
  1909. if model is not None:
  1910. # If we are using FP8, we need to apply the autowrap now
  1911. if self.fp8_backend == FP8BackendType.TE:
  1912. model = apply_fp8_autowrap(model, self.fp8_recipe_handler)
  1913. # if the model is an MOE, set the appropriate MOE layers as leaf Z3 modules
  1914. deepspeed_plugin.set_moe_leaf_modules(model)
  1915. # deal with config keys that use `auto` value and rely on model's hidden_size
  1916. hidden_size_based_keys = [
  1917. "zero_optimization.reduce_bucket_size",
  1918. "zero_optimization.stage3_prefetch_bucket_size",
  1919. "zero_optimization.stage3_param_persistence_threshold",
  1920. ]
  1921. hidden_size_auto_keys = [x for x in hidden_size_based_keys if deepspeed_plugin.is_auto(x)]
  1922. if len(hidden_size_auto_keys) > 0:
  1923. reasoning = (
  1924. "therefore it's not possible to automatically fill out the following `auto` entries "
  1925. + f"in the DeepSpeed config file: {hidden_size_auto_keys}. You can fix that by replacing "
  1926. + "`auto` values for these keys with an integer value of your choice."
  1927. )
  1928. if not hasattr(model, "config"):
  1929. raise ValueError("Can't find `model.config` entry, " + reasoning)
  1930. if hasattr(model.config, "hidden_size"):
  1931. hidden_size = model.config.hidden_size
  1932. elif hasattr(model.config, "hidden_sizes"):
  1933. # if there are many hidden sizes pick the largest one
  1934. hidden_size = max(model.config.hidden_sizes)
  1935. else:
  1936. raise ValueError(
  1937. "Can find neither `model.config.hidden_size` nor `model.config.hidden_sizes`, " + reasoning
  1938. )
  1939. config_kwargs.update(
  1940. {
  1941. "zero_optimization.reduce_bucket_size": hidden_size * hidden_size,
  1942. "zero_optimization.stage3_prefetch_bucket_size": int(0.9 * hidden_size * hidden_size),
  1943. "zero_optimization.stage3_param_persistence_threshold": 10 * hidden_size,
  1944. }
  1945. )
  1946. if isinstance(optimizer, (DummyOptim)):
  1947. config_kwargs.update(
  1948. {"optimizer.params.lr": optimizer.lr, "optimizer.params.weight_decay": optimizer.weight_decay}
  1949. )
  1950. if isinstance(scheduler, (DummyScheduler)) and scheduler.lr_scheduler_callable is None:
  1951. max_lr = (
  1952. getattr(scheduler.optimizer, "lr", None)
  1953. if getattr(scheduler.optimizer, "defaults", None) is None
  1954. else scheduler.optimizer.defaults["lr"]
  1955. )
  1956. config_kwargs.update(
  1957. {
  1958. "scheduler.params.warmup_min_lr": 0,
  1959. "scheduler.params.warmup_max_lr": max_lr,
  1960. "scheduler.params.warmup_num_steps": scheduler.warmup_num_steps,
  1961. }
  1962. )
  1963. if scheduler.total_num_steps is not None:
  1964. config_kwargs["scheduler.params.total_num_steps"] = (
  1965. math.ceil(scheduler.total_num_steps / self.num_processes)
  1966. if not self.split_batches
  1967. else scheduler.total_num_steps
  1968. )
  1969. deepspeed_plugin.deepspeed_config_process(must_match=False, **config_kwargs)
  1970. self.deepspeed_config = deepspeed_plugin.deepspeed_config
  1971. # note: batch_size derivation is all over the map, especiall in HF Trainer, so try to fix it at the last moment if needed
  1972. pc = self.parallelism_config
  1973. if pc is not None and pc.sp_backend == "deepspeed" and pc.sp_size > 1:
  1974. self.deepspeed_config["train_batch_size"] = (
  1975. self.deepspeed_config["train_micro_batch_size_per_gpu"]
  1976. * self.deepspeed_config["gradient_accumulation_steps"]
  1977. * pc.data_parallel_size
  1978. )
  1979. kwargs = dict(model=model, config_params=self.deepspeed_config)
  1980. if optimizer is not None:
  1981. if isinstance(optimizer, (DummyOptim)):
  1982. kwargs["model_parameters"] = optimizer.params
  1983. if isinstance(scheduler, (DummyScheduler)) and scheduler.lr_scheduler_callable is not None:
  1984. kwargs["lr_scheduler"] = scheduler.lr_scheduler_callable
  1985. else:
  1986. if self.deepspeed_config["zero_optimization"].get("offload_optimizer", {}).get(
  1987. "device", "none"
  1988. ) != "none" and self.deepspeed_config.get("zero_force_ds_cpu_optimizer", True):
  1989. if self.device.type == "hpu" and os.environ.get("PT_HPU_LAZY_MODE", "1") == "1":
  1990. raise ValueError(
  1991. "You can't use an Offload Optimizer with HPU in Lazy Mode. "
  1992. "Please set the environment variable `PT_HPU_LAZY_MODE` to `0`."
  1993. )
  1994. optimizer = map_pytorch_optim_to_deepspeed(optimizer)
  1995. kwargs["optimizer"] = optimizer
  1996. if scheduler is not None:
  1997. if type(scheduler).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES:
  1998. kwargs["lr_scheduler"] = scheduler
  1999. if self.device.type == "hpu":
  2000. # This env variable is initialized here to make sure it is set to "true"
  2001. # It should be done by the launcher but it does not work for multi-node runs
  2002. os.environ["DEEPSPEED_USE_HPU"] = "true"
  2003. mpu = None
  2004. if sp_size > 1:
  2005. if sp_backend != "deepspeed":
  2006. raise ValueError(
  2007. f"In order to use the configured {sp_size=} with DeepSpeed, you need to configure sp_backend='deepspeed', yet you configured it to be {sp_backend=}."
  2008. )
  2009. ver_min_required = "0.18.2"
  2010. if not compare_versions("deepspeed", ">=", ver_min_required):
  2011. raise ImportError(
  2012. f"Deepspeed ALST/Ulysses requires deepspeed>={ver_min_required}. Please update DeepSpeed via `pip install deepspeed -U`."
  2013. )
  2014. from deepspeed.runtime.sequence_parallel.ulysses_sp import (
  2015. UlyssesSPAttentionHF,
  2016. UlyssesSPDataLoaderAdapter,
  2017. )
  2018. if not hasattr(model, "config"):
  2019. raise ValueError(
  2020. "UlyssesSPAttentionHF currently works with HF Transformers and expects the model object to have a config attribute but this model doesn't have one."
  2021. )
  2022. mpu = UlyssesSPAttentionHF.register_with_transformers(
  2023. model_name_or_path=model,
  2024. sequence_parallel_size=sp_size,
  2025. seq_length=sp_handler.sp_seq_length,
  2026. seq_length_is_variable=sp_handler.sp_seq_length_is_variable,
  2027. core_attn_implementation=sp_handler.sp_attn_implementation,
  2028. micro_batch_size=batch_size_per_device,
  2029. )
  2030. kwargs["mpu"] = mpu
  2031. for i in range(len(result)):
  2032. if isinstance(result[i], torch.utils.data.DataLoader):
  2033. if sp_size > 1:
  2034. # note that in case dataloader was prepared apart from model (for the external accelerator.prepare call) you'd need to call deepspeed_ulysses_dl_adapter after prepare(model) (see HF Trainer as the use-case)
  2035. sp_group = mpu.get_sequence_parallel_group()
  2036. sp_world_size = mpu.get_sequence_parallel_world_size()
  2037. sp_rank = mpu.get_sequence_parallel_rank()
  2038. result[i] = UlyssesSPDataLoaderAdapter(
  2039. result[i],
  2040. sp_rank=sp_rank,
  2041. sp_group=sp_group,
  2042. sp_world_size=sp_world_size,
  2043. device=self.device, # model.device,
  2044. )
  2045. engine, optimizer, _, lr_scheduler = ds_initialize(**kwargs)
  2046. if compare_versions("deepspeed", ">=", "0.14.4") and self.state.dynamo_plugin.backend != DynamoBackend.NO:
  2047. compile_kwargs = self.state.dynamo_plugin.to_kwargs()
  2048. if self.state.dynamo_plugin.use_regional_compilation:
  2049. compile_regions_deepspeed(engine.module, **compile_kwargs)
  2050. else:
  2051. engine.compile(backend=compile_kwargs.pop("backend"), compile_kwargs=compile_kwargs)
  2052. if optimizer is not None:
  2053. optimizer = DeepSpeedOptimizerWrapper(optimizer)
  2054. if scheduler is not None:
  2055. if lr_scheduler is None:
  2056. scheduler = AcceleratedScheduler(
  2057. scheduler,
  2058. optimizer,
  2059. step_with_optimizer=self.step_scheduler_with_optimizer,
  2060. split_batches=self.split_batches,
  2061. )
  2062. else:
  2063. scheduler = DeepSpeedSchedulerWrapper(lr_scheduler, optimizer)
  2064. for i in range(len(result)):
  2065. if isinstance(result[i], torch.nn.Module):
  2066. result[i] = engine
  2067. elif isinstance(result[i], (torch.optim.Optimizer, DummyOptim)):
  2068. result[i] = optimizer
  2069. elif (isinstance(result[i], (LRScheduler, DummyScheduler))) or (
  2070. type(result[i]).__name__ in deepspeed.runtime.lr_schedules.VALID_LR_SCHEDULES
  2071. ):
  2072. result[i] = scheduler
  2073. # pointing for deepspeed_engine_wrapped.backward()
  2074. if self.deepspeed_engine_wrapped is None:
  2075. self.deepspeed_engine_wrapped = DeepSpeedEngineWrapper(engine)
  2076. else:
  2077. logger.warning(
  2078. "A wrapped DeepSpeed engine reference is currently tied for this `Accelerator()` instance. "
  2079. "If you want to call `accelerator.backward()` referencing a new model/engine, "
  2080. "please create a separate `Accelerator()` instance and call `accelerator.prepare()` on it."
  2081. )
  2082. self._models.append(engine)
  2083. if optimizer is not None:
  2084. self._optimizers.append(optimizer)
  2085. if scheduler is not None:
  2086. self._schedulers.append(scheduler)
  2087. return tuple(result)
  2088. def deepspeed_ulysses_dl_adapter(self, dl, model):
  2089. """this is normally called as part of `prepare` but when dataloader was prepared apart from model (for the external accelerator.prepare call) this additional call needs to be made after prepare(model) (see HF Trainer as the use-case)"""
  2090. sp_size = self.parallelism_config.sp_size if self.parallelism_config else 1
  2091. if sp_size == 1:
  2092. return dl
  2093. from deepspeed.runtime.sequence_parallel.ulysses_sp import UlyssesSPDataLoaderAdapter
  2094. from deepspeed.utils import groups
  2095. sp_group = groups._get_sequence_parallel_group()
  2096. sp_world_size = groups._get_sequence_parallel_world_size()
  2097. sp_rank = groups._get_sequence_parallel_rank()
  2098. dl = UlyssesSPDataLoaderAdapter(
  2099. dl,
  2100. sp_rank=sp_rank,
  2101. sp_group=sp_group,
  2102. sp_world_size=sp_world_size,
  2103. device=model.device,
  2104. )
  2105. return dl
  2106. def _prepare_megatron_lm(self, *args):
  2107. megatron_lm_plugin = self.state.megatron_lm_plugin
  2108. micro_batch_size = None
  2109. if not megatron_lm_plugin.megatron_dataset_flag:
  2110. batch_sizes = [obj.batch_size for obj in args if hasattr(obj, "batch_size")]
  2111. if len(batch_sizes) == 0:
  2112. raise ValueError(
  2113. "You must specify a training or evaluation dataloader in `accelerate.prepare()` when using Megatron-LM."
  2114. )
  2115. micro_batch_size = min(batch_sizes) if megatron_lm_plugin.is_train_batch_min else max(batch_sizes)
  2116. if len(batch_sizes) > 1:
  2117. logger.info(
  2118. "Since you passed both train and evaluation dataloader, `is_train_batch_min` (here "
  2119. f"{megatron_lm_plugin.is_train_batch_min} will decide the `train_batch_size` ({micro_batch_size})."
  2120. )
  2121. else:
  2122. for obj in args:
  2123. if isinstance(obj, MegatronLMDummyDataLoader):
  2124. micro_batch_size = obj.dataset_args["micro_batch_size"]
  2125. break
  2126. if micro_batch_size is not None:
  2127. dp_degree = self.num_processes // (megatron_lm_plugin.tp_degree * megatron_lm_plugin.pp_degree)
  2128. megatron_lm_plugin.set_training_args(micro_batch_size, dp_degree)
  2129. else:
  2130. raise ValueError(
  2131. "When you do not pass the dataloader parameter, the `data_parallel_size`, "
  2132. "`micro_batch_size`, and `global_batch_size` megatron parameters will not be updated."
  2133. )
  2134. model = None
  2135. optimizer = None
  2136. scheduler = None
  2137. batch_data = None
  2138. for obj in args:
  2139. if isinstance(obj, torch.utils.data.DataLoader) and batch_data is None:
  2140. batch_data = next(iter(obj))
  2141. elif isinstance(obj, torch.nn.Module):
  2142. model = obj
  2143. elif isinstance(obj, (torch.optim.Optimizer)):
  2144. optimizer = obj
  2145. elif isinstance(obj, (LRScheduler, MegatronLMDummyScheduler)):
  2146. scheduler = obj
  2147. if model is not None:
  2148. megatron_lm_plugin.set_network_size_args(model, batch_data)
  2149. if optimizer is not None:
  2150. megatron_lm_plugin.set_optimizer_type(optimizer)
  2151. if scheduler is not None:
  2152. if not isinstance(scheduler, MegatronLMDummyScheduler):
  2153. raise ValueError(
  2154. "You can't use a custom scheduler with Megatron-LM. Please use the `accelerate.utils.MegatronLMDummyScheduler` instead."
  2155. )
  2156. megatron_lm_plugin.set_scheduler_args(scheduler)
  2157. # initialize megatron-lm
  2158. megatron_lm_initialize(self, args_defaults=megatron_lm_plugin.megatron_lm_default_args)
  2159. (model, optimizer, scheduler) = megatron_lm_prepare_model_optimizer_scheduler(self)
  2160. self.wait_for_everyone()
  2161. counter = 0
  2162. result = []
  2163. for obj in args:
  2164. if isinstance(obj, torch.utils.data.DataLoader):
  2165. result.append(megatron_lm_prepare_data_loader(self, obj))
  2166. counter += 1
  2167. elif isinstance(obj, MegatronLMDummyDataLoader):
  2168. if counter == 0:
  2169. obj.set_megatron_data_args()
  2170. dataloaders = megatron_lm_prepare_data_loader(self, obj)
  2171. result.append(dataloaders[counter])
  2172. counter += 1
  2173. else:
  2174. result.append(obj)
  2175. if model is not None:
  2176. model = MegatronEngine(self, model, optimizer, scheduler)
  2177. if optimizer is not None:
  2178. optimizer = MegatronLMOptimizerWrapper(optimizer)
  2179. if scheduler is not None:
  2180. scheduler = MegatronLMSchedulerWrapper(scheduler, optimizer)
  2181. for i in range(len(result)):
  2182. if isinstance(result[i], torch.nn.Module):
  2183. result[i] = model
  2184. elif isinstance(result[i], torch.optim.Optimizer):
  2185. result[i] = optimizer
  2186. elif isinstance(result[i], MegatronLMDummyScheduler):
  2187. result[i] = scheduler
  2188. if model is not None:
  2189. self._models.append(model)
  2190. if len(self._models) > 1:
  2191. raise AssertionError(
  2192. "You can't use same `Accelerator()` instance with multiple models when using Megatron-LM"
  2193. )
  2194. if optimizer is not None:
  2195. self._optimizers.append(optimizer)
  2196. if scheduler is not None:
  2197. self._schedulers.append(scheduler)
  2198. return tuple(result)
  2199. def _prepare_ipex(self, *args):
  2200. """
  2201. Prepares model and optimizer for training with IPEX on CPU/XPU. This covers 3 cases, IPEX compiled with CPU
  2202. only support, IPEX compiled with XPU support and training with XPU pytorch backend available in stock pytorch
  2203. starting from version 2.4.
  2204. """
  2205. # ipex.optimize() is available only for IPEX, both IPEX-CPU and IPEX-XPU
  2206. if is_ipex_available():
  2207. import intel_extension_for_pytorch as ipex
  2208. else:
  2209. raise ImportError(
  2210. "IPEX is not installed or IPEX's version does not match current PyTorch version. Please refer"
  2211. " to https://github.com/intel/intel-extension-for-pytorch."
  2212. )
  2213. models = []
  2214. optimizers = []
  2215. result = [obj for obj in args]
  2216. for i, obj in enumerate(result):
  2217. if isinstance(obj, torch.nn.Module):
  2218. model = obj
  2219. model.train()
  2220. models.append((i, model))
  2221. elif isinstance(obj, (torch.optim.Optimizer)):
  2222. optimizers.append((i, obj))
  2223. # Impossible to determine what to do if multiple models and/or optimizers are provided
  2224. if len(optimizers) > 1 or (len(models) > 1 and len(optimizers) == 1):
  2225. raise ValueError(
  2226. "Prepare with IPEX expects either 1+ models and no optimizer OR a single model-optimizer pair."
  2227. )
  2228. # Nothing to do
  2229. if len(models) == 0 and len(optimizers) == 0:
  2230. return result
  2231. dtype = torch.bfloat16 if self.state.mixed_precision == "bf16" else None
  2232. # Multiple models and no optimizer (inference) are provided
  2233. if len(models) > 0 and len(optimizers) == 0:
  2234. for i, model in models:
  2235. if self.device.type == "xpu" and next(model.parameters()).device.type == "cpu":
  2236. model = model.to(self.device)
  2237. model, _ = ipex.optimize(model, optimizer=None, dtype=dtype, inplace=True, level="O1")
  2238. # Replace in result
  2239. result[i] = model
  2240. # A single model-optimizer pair (training) is provided
  2241. if len(models) == 1 and len(optimizers) == 1:
  2242. i_model, model = models[0]
  2243. i_optimizer, optimizer = optimizers[0]
  2244. if self.device.type == "xpu" and next(model.parameters()).device.type == "cpu":
  2245. model = model.to(self.device)
  2246. model, optimizer = ipex.optimize(model, optimizer=optimizer, dtype=dtype, inplace=True, level="O1")
  2247. # Replace in result
  2248. result[i_model] = model
  2249. result[i_optimizer] = optimizer
  2250. return tuple(result)
  2251. def _prepare_device_mesh(self):
  2252. """
  2253. Prepare the device mesh for distributed training. The dataloader will determine how to load data based on the
  2254. device mesh.
  2255. """
  2256. if self.distributed_type == DistributedType.DEEPSPEED and hasattr(self.state, "ds_device_mesh"):
  2257. return self.state.ds_device_mesh
  2258. else:
  2259. return self.torch_device_mesh
  def _prepare_msamp(self, *args, device_placement):
      """
      Initialize a single model/optimizer pair with MS-AMP for FP8 mixed precision training.

      Args:
          *args:
              The objects passed to `prepare()`. At most one `torch.nn.Module` and one `torch.optim.Optimizer` may be
              among them; any other object (dataloader, scheduler, ...) is passed through untouched.
          device_placement (`list[bool]`):
              Per-object device placement flags aligned with `args`. The optimizer's entry is set to `False` in place,
              because MS-AMP moves the optimizer (but not the model) to the right device itself.

      Returns:
          `tuple`: `(objects, device_placement)`, where `objects` is a tuple aligned with `args` in which the model
          and optimizer have been replaced by the objects returned from `msamp.initialize`.

      Raises:
          ImportError: If MS-AMP is not installed.
          ValueError: If only one of model/optimizer is passed, or more than one model or optimizer is passed.
      """
      if not is_msamp_available():
          raise ImportError(
              "MS-AMP was not found on your system. Please ensure that MS-AMP is available "
              "or choose `'te'` as the backend for FP8 mixed precision training."
          )
      # We've already checked for FSDP + MS-AMP during `__init__`
      import msamp

      model, optimizer = None, None
      optimizer_index = None
      num_models, num_optimizers = 0, 0
      result = list(args)
      for i, obj in enumerate(result):
          if isinstance(obj, torch.nn.Module):
              model = obj
              num_models += 1
          elif isinstance(obj, torch.optim.Optimizer):
              optimizer = obj
              optimizer_index = i
              num_optimizers += 1

      # DataLoader/Scheduler case: nothing for MS-AMP to do. Return a tuple like the normal path does.
      if optimizer is None and model is None:
          return tuple(result), device_placement
      if optimizer is None or model is None:
          raise ValueError(
              "You must pass a model and an optimizer together to `accelerate.prepare()` when using MS-AMP."
          )
      if num_models > 1 or num_optimizers > 1:
          raise ValueError(
              f"You can't use multiple models ({num_models}) or optimizers ({num_optimizers}) with MS-AMP."
          )

      # DEPRECATE @ 2.0
      if self.fp8_recipe_handler is not None:
          opt_level = self.fp8_recipe_handler.opt_level
      else:
          opt_level = self.msamp_recipe_handler.opt_level
      model, optimizer = msamp.initialize(model, optimizer, opt_level=opt_level)

      # Swap the originals for the MS-AMP-initialized objects at their original positions.
      for i, obj in enumerate(result):
          if isinstance(obj, torch.nn.Module):
              result[i] = model
          elif isinstance(obj, torch.optim.Optimizer):
              result[i] = optimizer
      # NOTE: MS-AMP moves the optimizer, but *not* the model to the right device
      device_placement[optimizer_index] = False
      return tuple(result), device_placement
  def prepare_data_loader(
      self, data_loader: torch.utils.data.DataLoader, device_placement=None, slice_fn_for_dispatch=None
  ):
      """
      Prepare a single PyTorch DataLoader for whatever distributed setup this `Accelerator` runs in. Calling
      [`Accelerator.prepare`] is the recommended entry point.

      Args:
          data_loader (`torch.utils.data.DataLoader`):
              The plain PyTorch DataLoader to wrap.
          device_placement (`bool`, *optional*):
              Whether batches yielded by the prepared dataloader are moved to the proper device. Defaults to
              `self.device_placement`, or to `False` when running on XLA.
          slice_fn_for_dispatch (`Callable`, *optional*):
              Function used to slice tensors across `num_processes`. Defaults to [`~utils.slice_tensors`]. Only used
              when `dispatch_batches` is `True`; ignored otherwise.

      Returns:
          The prepared dataloader. A dataloader that was already prepared is returned unchanged (and registered in
          `self._dataloaders` if it was not tracked yet).

      Example:

      ```python
      >>> import torch
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> data_loader = torch.utils.data.DataLoader(...)
      >>> data_loader = accelerator.prepare_data_loader(data_loader, device_placement=True)
      ```
      """
      # `find_batch_size` can hand an already-wrapped dataloader back to us; never wrap it a second time.
      already_prepared = getattr(data_loader, "_is_accelerate_prepared", False)
      if already_prepared:
          if data_loader not in self._dataloaders:
              self._dataloaders.append(data_loader)
          return data_loader

      if device_placement is None:
          device_placement = False if self.distributed_type == DistributedType.XLA else self.device_placement
      mesh = self._prepare_device_mesh()
      wrapped_loader = prepare_data_loader(
          data_loader,
          self.device,
          num_processes=self.num_processes,
          process_index=self.process_index,
          split_batches=self.split_batches,
          put_on_device=device_placement,
          rng_types=self.rng_types.copy(),
          dispatch_batches=self.dispatch_batches,
          even_batches=self.even_batches,
          slice_fn_for_dispatch=slice_fn_for_dispatch,
          use_seedable_sampler=self.use_seedable_sampler,
          data_seed=self.dataloader_config.data_seed,
          non_blocking=self.non_blocking,
          use_stateful_dataloader=self.use_stateful_dataloader,
          torch_device_mesh=mesh,
      )
      self._dataloaders.append(wrapped_loader)
      return wrapped_loader
  def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement=None):
      """
      Wrap a single PyTorch optimizer so it works in any distributed setup. Using [`Accelerator.prepare`] is the
      recommended entry point.

      Args:
          optimizer (`torch.optim.Optimizer`):
              The plain PyTorch optimizer to wrap.
          device_placement (`bool`, *optional*):
              Whether the optimizer should be placed on the proper device. Falls back to `self.device_placement`.

      Returns:
          An `AcceleratedOptimizer` wrapping `optimizer`, or `optimizer` itself when it was already prepared.

      Example:

      ```python
      >>> import torch
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> optimizer = torch.optim.Adam(...)
      >>> optimizer = accelerator.prepare_optimizer(optimizer, device_placement=True)
      ```
      """
      if is_lomo_available():
          # Imported locally: lomo imports from transformers & accelerate, so a top-level import would be circular.
          from lomo_optim import AdaLomo, Lomo

          # OR-accumulate so that several optimizers can be prepared:
          # https://github.com/huggingface/accelerate/pull/2695#discussion_r1589164607
          self.has_lomo_optimizer |= isinstance(optimizer, (Lomo, AdaLomo))

      # `find_batch_size` can route an already-wrapped optimizer back here; never wrap it twice.
      if getattr(optimizer, "_is_accelerate_prepared", False):
          if optimizer not in self._optimizers:
              self._optimizers.append(optimizer)
          return optimizer

      placement = self.device_placement if device_placement is None else device_placement
      # With MS-AMP the scaler is deliberately *not* handed to `AcceleratedOptimizer`: their optimizer handles it.
      grad_scaler = None if self.fp8_backend == FP8BackendType.MSAMP else self.scaler
      wrapped_optimizer = AcceleratedOptimizer(optimizer, device_placement=placement, scaler=grad_scaler)
      self._optimizers.append(wrapped_optimizer)
      return wrapped_optimizer
  def prepare_scheduler(self, scheduler: LRScheduler):
      """
      Wrap a single PyTorch learning-rate scheduler so it works in any distributed setup. Using
      [`Accelerator.prepare`] is the recommended entry point.

      Args:
          scheduler (`torch.optim.lr_scheduler.LRScheduler`):
              The plain PyTorch scheduler to wrap.

      Returns:
          An `AcceleratedScheduler` wrapping `scheduler`, or `scheduler` itself when it was already prepared. The
          wrapper is bound to the prepared optimizer whose inner optimizer is `scheduler.optimizer`; when none
          matches, it is bound to the full list of prepared optimizers.

      Example:

      ```python
      >>> import torch
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> optimizer = torch.optim.Adam(...)
      >>> scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, ...)
      >>> scheduler = accelerator.prepare_scheduler(scheduler)
      ```
      """
      # `find_batch_size` can route an already-wrapped scheduler back here; never wrap it twice.
      if getattr(scheduler, "_is_accelerate_prepared", False):
          if scheduler not in self._schedulers:
              self._schedulers.append(scheduler)
          return scheduler

      scheduled_optimizer = getattr(scheduler, "optimizer", None)
      # First prepared optimizer wrapping the scheduler's optimizer; default is the whole list.
      matching_optimizer = next(
          (acc_opt for acc_opt in self._optimizers if scheduled_optimizer == acc_opt.optimizer),
          self._optimizers,
      )
      wrapped_scheduler = AcceleratedScheduler(
          scheduler,
          matching_optimizer,
          step_with_optimizer=self.step_scheduler_with_optimizer,
          split_batches=self.split_batches,
      )
      self._schedulers.append(wrapped_scheduler)
      return wrapped_scheduler
  def backward(self, loss, **kwargs):
      """
      Run the backward pass appropriate for the current configuration. Use it instead of `loss.backward()`.

      Outside of DeepSpeed, the loss is first divided by `self.gradient_accumulation_steps`; DeepSpeed does that
      itself. The backward call is then dispatched to:
      - the wrapped DeepSpeed engine under DeepSpeed;
      - nothing at all under Megatron-LM;
      - the gradient scaler when one is set;
      - `self.lomo_backward` when a `learning_rate` kwarg is given and a LOMO optimizer was prepared;
      - plain `loss.backward` otherwise.

      `kwargs` are forwarded to the DeepSpeed, scaler and plain backward calls.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator(gradient_accumulation_steps=2)
      >>> outputs = model(inputs)
      >>> loss = loss_fn(outputs, labels)
      >>> accelerator.backward(loss)
      ```
      """
      # NOTE(review): `learning_rate` is read but left inside `kwargs`, so on non-LOMO paths it is forwarded to the
      # underlying backward call as well -- confirm that callee accepts it.
      learning_rate = kwargs.get("learning_rate")
      running_deepspeed = self.distributed_type == DistributedType.DEEPSPEED
      if not running_deepspeed:
          # DeepSpeed divides by gradient_accumulation_steps inside its own `backward`.
          loss = loss / self.gradient_accumulation_steps

      if running_deepspeed:
          self.deepspeed_engine_wrapped.backward(loss, sync_gradients=self.sync_gradients, **kwargs)
          return
      if self.distributed_type == DistributedType.MEGATRON_LM:
          return
      if self.scaler is not None:
          self.scaler.scale(loss).backward(**kwargs)
      elif learning_rate is not None and self.has_lomo_optimizer:
          self.lomo_backward(loss, learning_rate)
      else:
          loss.backward(**kwargs)
  def set_trigger(self):
      """
      Sets the internal trigger tensor to 1 on the current process. A later call to [`Accelerator.check_trigger`]
      should follow; it checks the trigger across all processes.

      Note:
          Does not require `wait_for_everyone()`

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> # Assume later in the training script
      >>> # `should_do_breakpoint` is a custom function to monitor when to break,
      >>> # e.g. when the loss is NaN
      >>> if should_do_breakpoint(loss):
      ...     accelerator.set_trigger()
      >>> # Assume later in the training script
      >>> if accelerator.check_trigger():
      ...     break
      ```
      """
      # Replace (rather than mutate) the flag with a fresh tensor on this process's device.
      self.flag_tensor = torch.tensor(1, device=self.device)
  def check_trigger(self):
      """
      Report whether any process has set its trigger, and reset the local trigger when one did.

      The local flag tensor is reduced (summed) across processes with `self.reduce`. If the result is at least 1,
      the local flag is reset to 0 and `True` is returned; otherwise `False` is returned.

      Note:
          Does not require `wait_for_everyone()`

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> # Assume later in the training script
      >>> # `should_do_breakpoint` is a custom function to monitor when to break,
      >>> # e.g. when the loss is NaN
      >>> if should_do_breakpoint(loss):
      ...     accelerator.set_trigger()
      >>> # Assume later in the training script
      >>> if accelerator.check_trigger():
      ...     break
      ```
      """
      if self.flag_tensor is None:
          # Created lazily here: outside of `__init__` the tensor can be built on `self.device`.
          self.flag_tensor = torch.tensor(0, device=self.device)
      fired = self.reduce(self.flag_tensor).item() >= 1
      if not fired:
          return False
      self.flag_tensor = torch.tensor(0, device=self.device)
      return True
  def unscale_gradients(self, optimizer=None):
      """
      Unscale gradients when training with native AMP in fp16; in every other setting this does nothing. Usually
      reached through [`Accelerator.clip_grad_norm_`] or [`Accelerator.clip_grad_value_`].

      Args:
          optimizer (`torch.optim.Optimizer` or `list[torch.optim.Optimizer]`, *optional*):
              Optimizer(s) whose gradients get unscaled. When omitted, every optimizer passed to
              [`~Accelerator.prepare`] is unscaled.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> model, optimizer = accelerator.prepare(model, optimizer)
      >>> outputs = model(inputs)
      >>> loss = loss_fn(outputs, labels)
      >>> accelerator.backward(loss)
      >>> accelerator.unscale_gradients(optimizer=optimizer)
      ```
      """
      if not (self.native_amp and self.mixed_precision == "fp16"):
          return
      if optimizer is None:
          # TODO: this unscales all optimizers where we should only unscale the one where parameters are.
          targets = self._optimizers
      elif isinstance(optimizer, (tuple, list)):
          targets = optimizer
      else:
          targets = [optimizer]
      for candidate in targets:
          # Peel off every AcceleratedOptimizer layer: the scaler expects the raw torch optimizer.
          while isinstance(candidate, AcceleratedOptimizer):
              candidate = candidate.optimizer
          self.scaler.unscale_(candidate)
  def clip_grad_norm_(self, parameters, max_norm, norm_type=2):
      """
      Should be used in place of `torch.nn.utils.clip_grad_norm_`.

      Gradients are unscaled (see [`Accelerator.unscale_gradients`]) exactly once before clipping.
      - Under FSDP, when `parameters` are exactly the parameters of a prepared model, that model's own
        `clip_grad_norm_` is used (FSDP1), or `torch.nn.utils.clip_grad_norm_` (FSDP2).
      - Under DeepSpeed, no clipping happens here (DeepSpeed clips internally). The engine's global gradient norm
        is returned, or `None` if no engine is wrapped.
      - Under XLA, unsynced gradients are all-reduced first.

      Returns:
          `torch.Tensor`: Total norm of the parameter gradients (viewed as a single vector).

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator(gradient_accumulation_steps=2)
      >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)

      >>> for input, target in dataloader:
      ...     optimizer.zero_grad()
      ...     output = model(input)
      ...     loss = loss_func(output, target)
      ...     accelerator.backward(loss)
      ...     if accelerator.sync_gradients:
      ...         accelerator.clip_grad_norm_(model.parameters(), max_grad_norm)
      ...     optimizer.step()
      ```
      """

      def _same_parameters(params, model):
          # Compare by identity. `list == list` would call `Tensor.__eq__` on same-length lists of distinct tensors,
          # which either raises (ambiguous truth value) or can spuriously match equal-valued scalar parameters.
          model_params = list(model.parameters())
          return len(params) == len(model_params) and all(p is q for p, q in zip(params, model_params))

      # `GradScaler.unscale_` raises if called twice for the same optimizer before `update()`, so track whether a
      # branch below already unscaled before falling through to the generic path.
      already_unscaled = False
      if self.distributed_type == DistributedType.FSDP:
          self.unscale_gradients()
          already_unscaled = True
          parameters = list(parameters)
          for model in self._models:
              if _same_parameters(parameters, model):
                  if not self.is_fsdp2:
                      return model.clip_grad_norm_(max_norm, norm_type)
                  # viz: https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
                  return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
      elif self.distributed_type == DistributedType.DEEPSPEED:
          # DeepSpeed handles gradient clipping internally, but we can retrieve the gradient norm
          if self.deepspeed_engine_wrapped is not None:
              return self.deepspeed_engine_wrapped.get_global_grad_norm()
          return None
      elif self.distributed_type == DistributedType.XLA:
          # Reduce gradients first for XLA
          for acc_opt in self._optimizers:
              if not acc_opt.gradient_state.is_xla_gradients_synced:
                  opt = acc_opt
                  while isinstance(opt, AcceleratedOptimizer):
                      opt = opt.optimizer
                  gradients = xm._fetch_gradients(opt)
                  # Use xm.all_reduce to perform an in-place all-reduce. Recursive all-reduce each tensor
                  # one by one in self.reduce is non-inplace.
                  xm.all_reduce("sum", gradients, scale=1.0 / self.num_processes)
                  # Set is_xla_gradients_synced to True to avoid all-reduce twice in the AcceleratedOptimizer step.
                  acc_opt.gradient_state.is_xla_gradients_synced = True
          if os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true":
              self.unscale_gradients()
              already_unscaled = True
              parameters = list(parameters)
              for model in self._models:
                  if _same_parameters(parameters, model):
                      return model.clip_grad_norm_(max_norm, norm_type)
      if not already_unscaled:
          self.unscale_gradients()
      return torch.nn.utils.clip_grad_norm_(parameters, max_norm, norm_type=norm_type)
  def clip_grad_value_(self, parameters, clip_value):
      """
      Drop-in replacement for `torch.nn.utils.clip_grad_value_`. Gradients are unscaled first (see
      [`Accelerator.unscale_gradients`]). Not available with DeepSpeed or FSDP; those raise an `Exception`.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator(gradient_accumulation_steps=2)
      >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)

      >>> for input, target in dataloader:
      ...     optimizer.zero_grad()
      ...     output = model(input)
      ...     loss = loss_func(output, target)
      ...     accelerator.backward(loss)
      ...     if accelerator.sync_gradients:
      ...         accelerator.clip_grad_value_(model.parameters(), clip_value)
      ...     optimizer.step()
      ```
      """
      unsupported_backends = (DistributedType.DEEPSPEED, DistributedType.FSDP)
      if self.distributed_type in unsupported_backends:
          raise Exception("DeepSpeed and FSDP do not support `clip_grad_value_`. Use `clip_grad_norm_` instead.")
      self.unscale_gradients()
      torch.nn.utils.clip_grad_value_(parameters, clip_value)
  def gather(self, tensor):
      """
      Collect *tensor* from every process and concatenate the pieces along the first dimension, e.g. to regroup
      predictions during evaluation.

      Note:
          Every process takes part in this gather.

      Args:
          tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
              What to collect from all processes.

      Returns:
          `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`: The gathered tensor(s). The first
          dimension of each result is *num_processes* times the first dimension of the corresponding input.

      Example:

      ```python
      >>> # Assuming four processes
      >>> import torch
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> process_tensor = torch.tensor([accelerator.process_index], device=accelerator.device)
      >>> gathered_tensor = accelerator.gather(process_tensor)
      >>> gathered_tensor
      tensor([0, 1, 2, 3])
      ```
      """
      gathered = gather(tensor)
      return gathered
  def gather_for_metrics(self, input_data, use_gather_object=False):
      """
      Gathers `input_data` and potentially drops duplicates in the last batch if on a distributed system. Should be
      used for gathering the inputs and targets for metric calculation.

      Args:
          input_data (`torch.Tensor`, `object`, a nested tuple/list/dictionary of `torch.Tensor`, or a nested tuple/list/dictionary of `object`):
              The tensors or objects for calculating metrics across all processes.
          use_gather_object (`bool`, *optional*, defaults to `False`):
              Whether to forcibly use `gather_object` instead of `gather` (`gather_object` is already used when
              `input_data` contains anything other than tensors). This flag can be useful for gathering tensors with
              different sizes that we don't want to pad and concatenate along the first dimension. Using it with GPU
              tensors is not well supported and inefficient, as it incurs a GPU -> CPU transfer since tensors would
              be pickled.

      Returns:
          The gathered data. At the end of a dataloader with a positive `gradient_state.remainder`, only the first
          `remainder` samples are kept, since the last batch was padded with additional samples. When the remainder
          is unknown (`-1`, dataset without length), is `0`, or cannot be read, the full gathered data is returned.

      Example:

      ```python
      >>> # Assuming two processes, with a batch size of 5 on a dataset with 9 samples
      >>> import torch
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> dataloader = torch.utils.data.DataLoader(range(9), batch_size=5)
      >>> dataloader = accelerator.prepare(dataloader)
      >>> batch = next(iter(dataloader))
      >>> gathered_items = accelerator.gather_for_metrics(batch)
      >>> len(gathered_items)
      9
      ```
      """
      # `recursively_apply` raises TypeError on any non-tensor leaf; use that to detect pure-tensor input.
      try:
          recursively_apply(lambda x: x, input_data, error_on_other_type=True)
          all_tensors = True
      except TypeError:
          all_tensors = False

      use_gather_object = use_gather_object or not all_tensors
      data = gather_object(input_data) if use_gather_object else self.gather(input_data)

      try:
          if not self.gradient_state.end_of_dataloader:
              # Not at the end of the dataloader, no need to adjust the tensors
              return data
          remainder = self.gradient_state.remainder
          if remainder == -1:
              logger.info(
                  "The used dataset had no length, returning gathered tensors. You should drop the remainder yourself."
              )
              return data
          if remainder > 0:
              # Last batch needs to be truncated on distributed systems as it contains additional samples
              def _adjust_samples(tensor):
                  return tensor[:remainder]

              if use_gather_object:
                  # gather_object puts the objects in a flat list
                  return _adjust_samples(data)
              return recursively_apply(_adjust_samples, data)
          # No remainder even though at end of dataloader, so nothing to do.
          return data
      except Exception:
          # Deliberate best-effort: if the remainder bookkeeping fails (e.g. dataset had no length), return the
          # untrimmed gathered data rather than failing metric computation.
          return data
  def reduce(self, tensor, reduction="sum", scale=1.0):
      """
      Combine *tensor* across all processes according to *reduction*; every process receives the result.

      Args:
          tensor (`torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`):
              What to reduce across processes.
          reduction (`str`, *optional*, defaults to "sum"):
              One of 'sum', 'mean' or 'none'. With 'none', no operation is performed.
          scale (`float`, *optional*, defaults to 1.0):
              Scaling factor applied after the reduction; only honored on XLA.

      Returns:
          `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`:
              The reduced tensor(s).

      Example:

      ```python
      >>> # Assuming two processes
      >>> import torch
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> process_tensor = torch.arange(accelerator.num_processes) + 1 + (2 * accelerator.process_index)
      >>> process_tensor = process_tensor.to(accelerator.device)
      >>> reduced_tensor = accelerator.reduce(process_tensor, reduction="sum")
      >>> reduced_tensor
      tensor([4, 6])
      ```
      """
      reduced = reduce(tensor, reduction, scale)
      return reduced
  def pad_across_processes(self, tensor, dim=0, pad_index=0, pad_first=False):
      """
      Pad every tensor in a (possibly nested list/tuple/dict) structure to the same size on all processes, so the
      result can be gathered safely.

      Args:
          tensor (nested list/tuple/dictionary of `torch.Tensor`):
              The data that will later be gathered.
          dim (`int`, *optional*, defaults to 0):
              Dimension along which padding is added.
          pad_index (`int`, *optional*, defaults to 0):
              Fill value used for padding.
          pad_first (`bool`, *optional*, defaults to `False`):
              Pad at the start instead of the end.

      Returns:
          `torch.Tensor`, or a nested tuple/list/dictionary of `torch.Tensor`:
              The padded tensor(s).

      Example:

      ```python
      >>> # Assuming two processes, with the first processes having a tensor of size 1 and the second of size 2
      >>> import torch
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> process_tensor = torch.arange(accelerator.process_index + 1).to(accelerator.device)
      >>> padded_tensor = accelerator.pad_across_processes(process_tensor)
      >>> padded_tensor.shape
      torch.Size([2])
      ```
      """
      padded = pad_across_processes(tensor, dim=dim, pad_index=pad_index, pad_first=pad_first)
      return padded
  def unwrap_model(self, model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True):
      """
      Strip the extra wrapper layers that [`~Accelerator.prepare`] may have put around `model`; handy right before
      saving.

      Args:
          model (`torch.nn.Module`):
              The (possibly wrapped) model.
          keep_fp32_wrapper (`bool`, *optional*, defaults to `True`):
              Keep the mixed precision hook if one was added.
          keep_torch_compile (`bool`, *optional*, defaults to `True`):
              Keep the compiled wrapper if the model was compiled.

      Returns:
          `torch.nn.Module`: The unwrapped model.

      Example:

      ```python
      >>> # Assuming two GPU processes
      >>> from torch.nn.parallel import DistributedDataParallel
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> model = accelerator.prepare(MyModel())
      >>> print(model.__class__.__name__)
      DistributedDataParallel

      >>> model = accelerator.unwrap_model(model)
      >>> print(model.__class__.__name__)
      MyModel
      ```
      """
      unwrapped = extract_model_from_parallel(model, keep_fp32_wrapper, keep_torch_compile)
      return unwrapped
  def wait_for_everyone(self):
      """
      Block the current process until all processes have reached this point. With a single process this is a
      no-op. Commonly used before saving a model.

      Example:

      ```python
      >>> # Assuming two GPU processes
      >>> import time
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> if accelerator.is_main_process:
      ...     time.sleep(2)
      ... else:
      ...     print("I'm waiting for the main process to finish its sleep...")
      >>> accelerator.wait_for_everyone()
      >>> # Should print on every process at the same time
      >>> print("Everyone is here")
      ```
      """
      # Delegates to the module-level barrier helper of the same name.
      wait_for_everyone()
  @on_main_process
  def init_trackers(self, project_name: str, config: dict | None = None, init_kwargs: dict | None = None):
      """
      Initializes a run for all trackers stored in `self.log_with`, potentially with starting configurations.
      Only runs on the main process.

      Args:
          project_name (`str`):
              The name of the project. All trackers will save their data based on this
          config (`dict`, *optional*):
              Optional starting configuration to be logged.
          init_kwargs (`dict`, *optional*):
              A nested dictionary of kwargs to be passed to a specific tracker's `__init__` function, keyed by the
              tracker's string name. Should be formatted like so:
              ```python
              {"wandb": {"tags": ["tag_a", "tag_b"]}}
              ```
              `None` (the default) is treated as an empty dictionary.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator(log_with="tensorboard")
      >>> accelerator.init_trackers(
      ...     project_name="my_project",
      ...     config={"learning_rate": 0.001, "batch_size": 32},
      ...     init_kwargs={"tensorboard": {"flush_secs": 60}},
      ... )
      ```
      """
      # Avoid a shared mutable default and accept an explicit `None`.
      init_kwargs = {} if init_kwargs is None else init_kwargs
      for tracker in self.log_with:
          if isinstance(tracker, GeneralTracker):
              # Custom trackers are already initialized
              self.trackers.append(tracker)
              continue
          tracker_init = LOGGER_TYPE_TO_CLASS[str(tracker)]
          tracker_kwargs = init_kwargs.get(str(tracker), {})
          if tracker_init.requires_logging_directory:
              # We can skip this check since it was done in `__init__`
              self.trackers.append(tracker_init(project_name, self.logging_dir, **tracker_kwargs))
          else:
              self.trackers.append(tracker_init(project_name, **tracker_kwargs))
      for tracker in self.trackers:
          tracker.start()
      if config is not None:
          for tracker in self.trackers:
              tracker.store_init_configuration(config)
  def get_tracker(self, name: str, unwrap: bool = False):
      """
      Look up a tracker in `self.trackers` by its `.name`. Trackers only live on the main process.

      Args:
          name (`str`):
              The tracker's name, as exposed by its `.name` property.
          unwrap (`bool`):
              Return the underlying tracking object instead of the wrapping tracker (the wrapper is recommended).

      Returns:
          `GeneralTracker`: The matching tracker (or its inner object when `unwrap=True`). When no trackers exist,
          e.g. on non-main processes, a blank `GeneralTracker` is returned.

      Raises:
          ValueError: If trackers exist but none has the requested name.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator(log_with="tensorboard")
      >>> accelerator.init_trackers("my_project")
      >>> tensorboard_tracker = accelerator.get_tracker("tensorboard")
      ```
      """
      if not self.trackers:
          # Handle tracker only made on main process
          return GeneralTracker(_blank=True)
      for candidate in self.trackers:
          if candidate.name == name:
              return candidate.tracker if unwrap else candidate
      raise ValueError(f"{name} is not an available tracker stored inside the `Accelerator`.")
  @on_main_process
  def log(self, values: dict, step: int | None = None, log_kwargs: dict | None = None):
      """
      Logs `values` to all stored trackers in `self.trackers` on the main process only.

      Args:
          values (`dict`):
              Values should be a dictionary-like object containing only types `int`, `float`, or `str`.
          step (`int`, *optional*):
              The run step. If included, the log will be affiliated with this step.
          log_kwargs (`dict`, *optional*):
              A nested dictionary of kwargs to be passed to a specific tracker's `log` function, keyed by the
              tracker's `.name`. Should be formatted like so:
              ```python
              {"wandb": {"commit": False}}
              ```
              `None` (the default) is treated as an empty dictionary.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator(log_with="tensorboard")
      >>> accelerator.init_trackers("my_project")
      >>> accelerator.log({"loss": 0.5, "accuracy": 0.9})
      ```
      """
      # Avoid a shared mutable default and accept an explicit `None`.
      log_kwargs = {} if log_kwargs is None else log_kwargs
      for tracker in self.trackers:
          tracker.log(values, step=step, **log_kwargs.get(tracker.name, {}))
  def end_training(self):
      """
      Runs any special end-of-training behaviors: finishes every tracker in `self.trackers`, then destroys the
      process group via `self.state`. Should always be called at the end of your script if using experiment
      tracking.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator(log_with="tensorboard")
      >>> accelerator.init_trackers("my_project")
      >>> # Do training
      >>> accelerator.end_training()
      ```
      """
      for tracker in self.trackers:
          tracker.finish()
      # Trackers are finished first, then the distributed process group is torn down.
      self.state.destroy_process_group()
  def save(self, obj, f, safe_serialization=False):
      """
      Write `obj` to disk once per machine; a replacement for `torch.save`.

      Args:
          obj (`object`): What to save.
          f (`str` or `os.PathLike`): Destination of the saved content.
          safe_serialization (`bool`, *optional*, defaults to `False`): Save `obj` with `safetensors` instead.

      Note:
          When `save_on_each_node` is enabled in the `ProjectConfiguration`, the object is written once per node
          instead of only once on the main node.

      Example:

      ```python
      >>> from accelerate import Accelerator

      >>> accelerator = Accelerator()
      >>> arr = [0, 1, 2, 3]
      >>> accelerator.save(arr, "array.pkl")
      ```
      """
      per_node = self.project_configuration.save_on_each_node
      # Delegates to the module-level `save` helper (not this method).
      save(obj, f, save_on_each_node=per_node, safe_serialization=safe_serialization)
  2946. def save_model(
  2947. self,
  2948. model: torch.nn.Module,
  2949. save_directory: Union[str, os.PathLike],
  2950. max_shard_size: Union[int, str] = "10GB",
  2951. safe_serialization: bool = True,
  2952. ):
  2953. """
  2954. Save a model so that it can be re-loaded using load_checkpoint_in_model
  2955. Arguments:
  2956. model: (`torch.nn.Module`):
  2957. Model to be saved. The model can be wrapped or unwrapped.
  2958. save_directory (`str` or `os.PathLike`):
  2959. Directory to which to save. Will be created if it doesn't exist.
  2960. max_shard_size (`int` or `str`, *optional*, defaults to `"10GB"`):
  2961. The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
  2962. lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5MB"`).
  2963. <Tip warning={true}>
  2964. If a single weight of the model is bigger than `max_shard_size`, it will be in its own checkpoint shard
  2965. which will be bigger than `max_shard_size`.
  2966. </Tip>
  2967. safe_serialization (`bool`, *optional*, defaults to `True`):
  2968. Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
  2969. Example:
  2970. ```python
  2971. >>> from accelerate import Accelerator
  2972. >>> accelerator = Accelerator()
  2973. >>> model = ...
  2974. >>> accelerator.save_model(model, save_directory)
  2975. ```
  2976. """
  2977. if os.path.isfile(save_directory):
  2978. logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
  2979. return
  2980. # get the state_dict of the model
  2981. if any(has_offloaded_params(module) for module in model.modules()):
  2982. state_dict = get_state_dict_offloaded_model(model)
  2983. else:
  2984. if any(param.device == torch.device("meta") for param in model.parameters()):
  2985. raise RuntimeError("You can't save the model since some parameters are on the meta device.")
  2986. state_dict = self.get_state_dict(model)
  2987. # Case: DeepSpeed zero3 gets gathered and `state_dict` is empty
  2988. if state_dict is None:
  2989. return
  2990. os.makedirs(save_directory, exist_ok=True)
  2991. if safe_serialization:
  2992. state_dict = clean_state_dict_for_safetensors(state_dict)
  2993. weights_name = SAFE_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
  2994. filename_pattern = SAFE_WEIGHTS_PATTERN_NAME if safe_serialization else WEIGHTS_PATTERN_NAME
  2995. state_dict_split = split_torch_state_dict_into_shards(
  2996. state_dict, filename_pattern=filename_pattern, max_shard_size=max_shard_size
  2997. )
  2998. # Clean the folder from a previous save
  2999. for filename in os.listdir(save_directory):
  3000. full_filename = os.path.join(save_directory, filename)
  3001. # If we have a shard file that is not going to be replaced, we delete it, but only from the main process
  3002. # in distributed settings to avoid race conditions.
  3003. weights_no_suffix = weights_name.replace(".bin", "")
  3004. # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
  3005. filename_no_suffix = filename.replace(".bin", "")
  3006. reg = re.compile(r"(.*?)-\d{5}-of-\d{5}")
  3007. if (
  3008. filename.startswith(weights_no_suffix)
  3009. and os.path.isfile(full_filename)
  3010. and filename not in state_dict_split.filename_to_tensors.keys()
  3011. and reg.fullmatch(filename_no_suffix) is not None
  3012. and PartialState().is_main_process
  3013. ):
  3014. os.remove(full_filename)
  3015. # Save the model
  3016. for filename, tensors in state_dict_split.filename_to_tensors.items():
  3017. shard = {tensor: state_dict[tensor] for tensor in tensors}
  3018. self.save(shard, os.path.join(save_directory, filename), safe_serialization=safe_serialization)
  3019. # Save index if sharded
  3020. if state_dict_split.is_sharded:
  3021. index = {
  3022. "metadata": state_dict_split.metadata,
  3023. "weight_map": state_dict_split.tensor_to_filename,
  3024. }
  3025. save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
  3026. save_index_file = os.path.join(save_directory, save_index_file)
  3027. with open(save_index_file, "w", encoding="utf-8") as f:
  3028. content = json.dumps(index, indent=2, sort_keys=True) + "\n"
  3029. f.write(content)
  3030. logger.info(
  3031. f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
  3032. f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
  3033. f"index located at {save_index_file}."
  3034. )
  3035. else:
  3036. path_to_weights = os.path.join(save_directory, WEIGHTS_NAME)
  3037. logger.info(f"Model weights saved in {path_to_weights}")
  3038. def register_save_state_pre_hook(self, hook: Callable[..., None]) -> hooks.RemovableHandle:
  3039. """
  3040. Registers a pre hook to be run before `save_checkpoint` is called in [`Accelerator.save_state`].
  3041. Args:
  3042. hook (`Callable`):
  3043. A function to be called in [`Accelerator.save_state`] before `save_checkpoint`.
  3044. The hook should have the following signature:
  3045. `hook(models: list[torch.nn.Module], weights: list[dict[str, torch.Tensor]], input_dir: str) -> None`
  3046. The `models` argument are the models as saved in the accelerator state under `accelerator._models`, `weights`
  3047. argument are the state dicts of the `models`, and the `input_dir` argument is the `input_dir` argument passed
  3048. to [`Accelerator.load_state`].
  3049. <Tip>
  3050. Should only be used in conjunction with [`Accelerator.register_load_state_pre_hook`]. Can be useful to save
  3051. configurations in addition to model weights. Can also be used to overwrite model saving with a customized
  3052. method. In this case, make sure to remove already loaded weights from the weights list.
  3053. </Tip>
  3054. Returns:
  3055. `torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling
  3056. `handle.remove()`
  3057. """
  3058. handle = hooks.RemovableHandle(self._save_model_state_pre_hook)
  3059. self._save_model_state_pre_hook[handle.id] = hook
  3060. return handle
  3061. def save_state(self, output_dir: str | None = None, safe_serialization: bool = True, **save_model_func_kwargs):
  3062. """
  3063. Saves the current states of the model, optimizer, scaler, RNG generators, and registered objects to a folder.
  3064. If a `ProjectConfiguration` was passed to the `Accelerator` object with `automatic_checkpoint_naming` enabled
  3065. then checkpoints will be saved to `self.project_dir/checkpoints`. If the number of current saves is greater
  3066. than `total_limit` then the oldest save is deleted. Each checkpoint is saved in separate folders named
  3067. `checkpoint_<iteration>`.
  3068. Otherwise they are just saved to `output_dir`.
  3069. <Tip>
  3070. Should only be used when wanting to save a checkpoint during training and restoring the state in the same
  3071. environment.
  3072. </Tip>
  3073. Args:
  3074. output_dir (`str` or `os.PathLike`):
  3075. The name of the folder to save all relevant weights and states.
  3076. safe_serialization (`bool`, *optional*, defaults to `True`):
  3077. Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
  3078. save_model_func_kwargs (`dict`, *optional*):
  3079. Additional keyword arguments for saving model which can be passed to the underlying save function, such
  3080. as optional arguments for DeepSpeed's `save_checkpoint` function.
  3081. Example:
  3082. ```python
  3083. >>> from accelerate import Accelerator
  3084. >>> accelerator = Accelerator()
  3085. >>> model, optimizer, lr_scheduler = ...
  3086. >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
  3087. >>> accelerator.save_state(output_dir="my_checkpoint")
  3088. ```
  3089. """
  3090. if self.project_configuration.automatic_checkpoint_naming:
  3091. output_dir = os.path.join(self.project_dir, "checkpoints")
  3092. os.makedirs(output_dir, exist_ok=True)
  3093. if self.project_configuration.automatic_checkpoint_naming:
  3094. folders = [os.path.join(output_dir, folder) for folder in os.listdir(output_dir)]
  3095. if (
  3096. self.project_configuration.total_limit is not None
  3097. and (len(folders) + 1 > self.project_configuration.total_limit)
  3098. and self.is_main_process
  3099. ):
  3100. def _inner(folder):
  3101. return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
  3102. folders.sort(key=_inner)
  3103. logger.warning(
  3104. f"Deleting {len(folders) + 1 - self.project_configuration.total_limit} checkpoints to make room for new checkpoint."
  3105. )
  3106. for folder in folders[: len(folders) + 1 - self.project_configuration.total_limit]:
  3107. shutil.rmtree(folder)
  3108. output_dir = os.path.join(output_dir, f"checkpoint_{self.save_iteration}")
  3109. if os.path.exists(output_dir):
  3110. raise ValueError(
  3111. f"Checkpoint directory {output_dir} ({self.save_iteration}) already exists. Please manually override `self.save_iteration` with what iteration to start with."
  3112. )
  3113. self.wait_for_everyone()
  3114. os.makedirs(output_dir, exist_ok=True)
  3115. logger.info(f"Saving current state to {output_dir}")
  3116. if self.distributed_type == DistributedType.XLA:
  3117. # Finish running the previous step before checkpointing
  3118. xm.mark_step()
  3119. # Save the models taking care of FSDP and DeepSpeed nuances
  3120. weights = []
  3121. for i, model in enumerate(self._models):
  3122. if self.distributed_type == DistributedType.FSDP:
  3123. logger.info("Saving FSDP model")
  3124. save_fsdp_model(self.state.fsdp_plugin, self, model, output_dir, i)
  3125. logger.info(f"FSDP Model saved to output dir {output_dir}")
  3126. elif self.distributed_type == DistributedType.DEEPSPEED:
  3127. logger.info("Saving DeepSpeed Model and Optimizer")
  3128. ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
  3129. model.save_checkpoint(output_dir, ckpt_id, **save_model_func_kwargs)
  3130. logger.info(f"DeepSpeed Model and Optimizer saved to output dir {os.path.join(output_dir, ckpt_id)}")
  3131. elif self.distributed_type == DistributedType.MEGATRON_LM:
  3132. logger.info("Saving Megatron-LM Model, Optimizer and Scheduler")
  3133. model.save_checkpoint(output_dir)
  3134. logger.info(f"Megatron-LM Model , Optimizer and Scheduler saved to output dir {output_dir}")
  3135. else:
  3136. weights.append(self.get_state_dict(model, unwrap=False))
  3137. # Save the optimizers taking care of FSDP and DeepSpeed nuances
  3138. optimizers = []
  3139. if self.distributed_type == DistributedType.FSDP:
  3140. for i, opt in enumerate(self._optimizers):
  3141. logger.info("Saving FSDP Optimizer")
  3142. save_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], output_dir, i)
  3143. logger.info(f"FSDP Optimizer saved to output dir {output_dir}")
  3144. elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
  3145. optimizers = self._optimizers
  3146. # Save the lr schedulers taking care of DeepSpeed nuances
  3147. schedulers = []
  3148. if self.distributed_type == DistributedType.DEEPSPEED:
  3149. for i, scheduler in enumerate(self._schedulers):
  3150. if isinstance(scheduler, DeepSpeedSchedulerWrapper):
  3151. continue
  3152. schedulers.append(scheduler)
  3153. elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
  3154. schedulers = self._schedulers
  3155. # Save the samplers of the dataloaders
  3156. dataloaders = self._dataloaders
  3157. # Call model loading hooks that might have been registered with
  3158. # accelerator.register_model_state_hook
  3159. for hook in self._save_model_state_pre_hook.values():
  3160. hook(self._models, weights, output_dir)
  3161. save_location = save_accelerator_state(
  3162. output_dir,
  3163. weights,
  3164. optimizers,
  3165. schedulers,
  3166. dataloaders,
  3167. self.state.process_index,
  3168. self.step,
  3169. self.scaler,
  3170. save_on_each_node=self.project_configuration.save_on_each_node,
  3171. safe_serialization=safe_serialization,
  3172. )
  3173. for i, obj in enumerate(self._custom_objects):
  3174. save_custom_state(obj, output_dir, i, save_on_each_node=self.project_configuration.save_on_each_node)
  3175. self.project_configuration.iteration += 1
  3176. return save_location
  3177. def register_load_state_pre_hook(self, hook: Callable[..., None]) -> hooks.RemovableHandle:
  3178. """
  3179. Registers a pre hook to be run before [`load_checkpoint`] is called in [`Accelerator.load_state`].
  3180. Args:
  3181. hook (`Callable`):
  3182. A function to be called in [`Accelerator.load_state`] before `load_checkpoint`.
  3183. The hook should have the following signature:
  3184. `hook(models: list[torch.nn.Module], input_dir: str) -> None`
  3185. The `models` argument are the models as saved in the accelerator state under `accelerator._models`, and the
  3186. `input_dir` argument is the `input_dir` argument passed to [`Accelerator.load_state`].
  3187. <Tip>
  3188. Should only be used in conjunction with [`Accelerator.register_save_state_pre_hook`]. Can be useful to load
  3189. configurations in addition to model weights. Can also be used to overwrite model loading with a customized
  3190. method. In this case, make sure to remove already loaded models from the models list.
  3191. </Tip>
  3192. Returns:
  3193. `torch.utils.hooks.RemovableHandle`: a handle that can be used to remove the added hook by calling
  3194. `handle.remove()`
  3195. """
  3196. handle = hooks.RemovableHandle(self._load_model_state_pre_hook)
  3197. self._load_model_state_pre_hook[handle.id] = hook
  3198. return handle
  3199. def load_state(self, input_dir: str | None = None, load_kwargs: dict | None = None, **load_model_func_kwargs):
  3200. """
  3201. Loads the current states of the model, optimizer, scaler, RNG generators, and registered objects.
  3202. <Tip>
  3203. Should only be used in conjunction with [`Accelerator.save_state`]. If a file is not registered for
  3204. checkpointing, it will not be loaded if stored in the directory.
  3205. </Tip>
  3206. Args:
  3207. input_dir (`str` or `os.PathLike`):
  3208. The name of the folder all relevant weights and states were saved in. Can be `None` if
  3209. `automatic_checkpoint_naming` is used, and will pick up from the latest checkpoint.
  3210. load_kwargs (`dict`, *optional*):
  3211. Additional keyword arguments for the underlying `load` function, such as optional arguments for
  3212. state_dict and optimizer on.
  3213. load_model_func_kwargs (`dict`, *optional*):
  3214. Additional keyword arguments for loading model which can be passed to the underlying load function,
  3215. such as optional arguments for DeepSpeed's `load_checkpoint` function or a `map_location` to load the
  3216. model and optimizer on.
  3217. Example:
  3218. ```python
  3219. >>> from accelerate import Accelerator
  3220. >>> accelerator = Accelerator()
  3221. >>> model, optimizer, lr_scheduler = ...
  3222. >>> model, optimizer, lr_scheduler = accelerator.prepare(model, optimizer, lr_scheduler)
  3223. >>> accelerator.load_state("my_checkpoint")
  3224. ```
  3225. """
  3226. if input_dir is not None:
  3227. # Check if folder exists
  3228. input_dir = os.path.expanduser(input_dir)
  3229. if not os.path.isdir(input_dir):
  3230. raise ValueError(f"Tried to find {input_dir} but folder does not exist")
  3231. elif self.project_configuration.automatic_checkpoint_naming:
  3232. # Pick up from automatic checkpoint naming
  3233. input_dir = os.path.join(self.project_dir, "checkpoints")
  3234. folders = [os.path.join(input_dir, folder) for folder in os.listdir(input_dir)]
  3235. def _inner(folder):
  3236. return list(map(int, re.findall(r"[\/]?([0-9]+)(?=[^\/]*$)", folder)))[0]
  3237. folders.sort(key=_inner)
  3238. input_dir = folders[-1]
  3239. else:
  3240. raise ValueError("No input_dir provided and automatic checkpoint naming is disabled.")
  3241. logger.info(f"Loading states from {input_dir}")
  3242. # Load the models taking care of FSDP and DeepSpeed nuances
  3243. models = []
  3244. for i, model in enumerate(self._models):
  3245. if self.distributed_type == DistributedType.FSDP:
  3246. logger.info("Loading FSDP model")
  3247. load_fsdp_model(self.state.fsdp_plugin, self, model, input_dir, i)
  3248. logger.info(f"FSDP Model loaded from input dir {input_dir}")
  3249. elif self.distributed_type == DistributedType.DEEPSPEED:
  3250. logger.info("Loading DeepSpeed Model and Optimizer")
  3251. ckpt_id = f"{MODEL_NAME}" if i == 0 else f"{MODEL_NAME}_{i}"
  3252. model.load_checkpoint(input_dir, ckpt_id, **load_model_func_kwargs)
  3253. logger.info(f"DeepSpeed Model and Optimizer loaded from input dir {os.path.join(input_dir, ckpt_id)}")
  3254. elif self.distributed_type == DistributedType.MEGATRON_LM:
  3255. logger.info("Loading Megatron-LM Model, Optimizer and Scheduler")
  3256. model.load_checkpoint(input_dir)
  3257. logger.info(f"Megatron-LM Model , Optimizer and Scheduler loaded from input dir {input_dir}")
  3258. else:
  3259. models.append(model)
  3260. # We need to load the scaler state before the optimizer for FSDP2
  3261. # (`torch.distributed.checkpoint.set_optimizer_state_dict`) which we use to set the state of the optimizer calls `optimizer.step` on
  3262. # a dummy tensor, but since the scaler is not initialized, it will raise an error (the scaler exists but its `_scale` is None)
  3263. scaler = None
  3264. if self.scaler is not None and self.is_fsdp2:
  3265. input_scaler_file = os.path.join(input_dir, SCALER_NAME)
  3266. scaler_state = torch.load(input_scaler_file)
  3267. self.scaler.load_state_dict(scaler_state)
  3268. # We also need to call the `_lazy_init_scale_growth_tracker` to initialize the scaler, as it would else be called
  3269. # on the first call to scale
  3270. self.scaler._lazy_init_scale_growth_tracker(self.scaler._device)
  3271. logger.info("GradScaler state loaded successfully")
  3272. else:
  3273. scaler = self.scaler
  3274. # Load the optimizers taking care of FSDP and DeepSpeed nuances
  3275. optimizers = []
  3276. if self.distributed_type == DistributedType.FSDP:
  3277. for i, opt in enumerate(self._optimizers):
  3278. logger.info("Loading FSDP Optimizer")
  3279. load_fsdp_optimizer(self.state.fsdp_plugin, self, opt, self._models[i], input_dir, i)
  3280. logger.info(f"FSDP Optimizer loaded from input dir {input_dir}")
  3281. elif self.distributed_type not in [DistributedType.DEEPSPEED, DistributedType.MEGATRON_LM]:
  3282. optimizers = self._optimizers
  3283. # Load the lr schedulers taking care of DeepSpeed nuances
  3284. schedulers = []
  3285. if self.distributed_type == DistributedType.DEEPSPEED:
  3286. for i, scheduler in enumerate(self._schedulers):
  3287. if isinstance(scheduler, DeepSpeedSchedulerWrapper):
  3288. continue
  3289. schedulers.append(scheduler)
  3290. elif self.distributed_type not in [DistributedType.MEGATRON_LM]:
  3291. schedulers = self._schedulers
  3292. dataloaders = self._dataloaders
  3293. # Call model loading hooks that might have been registered with
  3294. # accelerator.register_model_state_hook
  3295. for hook in self._load_model_state_pre_hook.values():
  3296. hook(models, input_dir)
  3297. map_location = load_model_func_kwargs.pop("map_location", None)
  3298. if map_location is None:
  3299. if self.num_processes > 1 and self.multi_device and self.distributed_type != DistributedType.MULTI_XPU:
  3300. map_location = "on_device"
  3301. else:
  3302. map_location = "cpu"
  3303. override_attributes = load_accelerator_state(
  3304. input_dir,
  3305. models,
  3306. optimizers,
  3307. schedulers,
  3308. dataloaders,
  3309. self.state.process_index,
  3310. scaler,
  3311. map_location,
  3312. load_kwargs,
  3313. **load_model_func_kwargs,
  3314. )
  3315. if "step" in override_attributes:
  3316. self.step = override_attributes["step"]
  3317. custom_checkpoints = [
  3318. f for f in os.listdir(input_dir) if re.search(r"^custom_checkpoint_\d+\.pkl$", f) is not None
  3319. ]
  3320. if len(custom_checkpoints) != len(self._custom_objects):
  3321. err = (
  3322. f"Number of custom checkpoints in folder {input_dir} does not match the number of registered objects:"
  3323. )
  3324. err += f"\n\tFound checkpoints: {len(custom_checkpoints)}"
  3325. err += f"\n\tRegistered objects: {len(self._custom_objects)}\n"
  3326. err += "Please make sure to only load checkpoints from folders that were created with the same set of registered objects,"
  3327. err += "or avoid using `custom_checkpoint` in the filename for files in that same directory and load them in manually."
  3328. raise RuntimeError(err)
  3329. else:
  3330. logger.info(f"Loading in {len(custom_checkpoints)} custom states")
  3331. for index, obj in enumerate(self._custom_objects):
  3332. load_custom_state(obj, input_dir, index)
  3333. def free_memory(self, *objects):
  3334. """
  3335. Will release all references to the internal objects stored and call the garbage collector. You should call this
  3336. method between two trainings with different models/optimizers. Also will reset `Accelerator.step` to 0.
  3337. Example:
  3338. ```python
  3339. >>> from accelerate import Accelerator
  3340. >>> accelerator = Accelerator()
  3341. >>> model, optimizer, scheduler = ...
  3342. >>> model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
  3343. >>> model, optimizer, scheduler = accelerator.free_memory(model, optimizer, scheduler)
  3344. ```
  3345. """
  3346. # Deepspeed needs a bit more prep that should be done first
  3347. if hasattr(self, "deepspeed_engine_wrapped"):
  3348. if self.deepspeed_engine_wrapped is not None:
  3349. self.deepspeed_engine_wrapped.engine.destroy()
  3350. self.deepspeed_engine_wrapped = None
  3351. objects = release_memory(*objects)
  3352. self._schedulers = []
  3353. self._optimizers = []
  3354. self._models = []
  3355. self._dataloaders = []
  3356. self.step = 0
  3357. return objects
  3358. def clear(self, *objects):
  3359. """
  3360. Alias for [`Accelerate.free_memory`], releases all references to the internal objects stored and call the
  3361. garbage collector. You should call this method between two trainings with different models/optimizers.
  3362. Example:
  3363. ```python
  3364. >>> from accelerate import Accelerator
  3365. >>> accelerator = Accelerator()
  3366. >>> model, optimizer, scheduler = ...
  3367. >>> model, optimizer, scheduler = accelerator.prepare(model, optimizer, scheduler)
  3368. >>> model, optimizer, scheduler = accelerator.clear(model, optimizer, scheduler)
  3369. ```
  3370. """
  3371. return self.free_memory(*objects)
  3372. def _get_named_parameters(self, *args, drop_refs=False):
  3373. named_parameters = {}
  3374. accessor_mapping = {}
  3375. for obj in args:
  3376. if isinstance(obj, torch.nn.Module):
  3377. obj = extract_model_from_parallel(obj)
  3378. if not drop_refs:
  3379. named_parameters.update({n: p for n, p in obj.named_parameters()})
  3380. continue
  3381. # we need this bit as `WeightWithDynamic...` returns 0 when `data_ptr()` is called,
  3382. # the underlying pointer is actually hidden in `_tensor` attribute
  3383. if self.fp8_backend == FP8BackendType.AO:
  3384. from torchao.float8.fsdp_utils import WeightWithDynamicFloat8CastTensor
  3385. accessor_mapping[WeightWithDynamicFloat8CastTensor] = "_tensor"
  3386. # we know we're in FSDP2 so DTensor is available
  3387. if self.is_fsdp2:
  3388. from torch.distributed.tensor import DTensor
  3389. accessor_mapping[DTensor] = "_local_tensor"
  3390. named_parameters.update(
  3391. {
  3392. n: getattr(p, accessor_mapping[type(p)]).data_ptr()
  3393. if type(p) in accessor_mapping
  3394. else p.data_ptr()
  3395. for n, p in obj.named_parameters()
  3396. }
  3397. )
  3398. return named_parameters
  3399. def _get_devices(self, *args):
  3400. model_device = None
  3401. optimizer_device = None
  3402. for obj in args:
  3403. # Loop through model parameters and stop at the first once we have its device.
  3404. if isinstance(obj, torch.nn.Module):
  3405. for param in obj.parameters():
  3406. model_device = param.device
  3407. break
  3408. # Loop through optimizer parameters groups and stop at the first once we have its device.
  3409. if isinstance(obj, torch.optim.Optimizer):
  3410. for param_group in obj.param_groups:
  3411. if len(param_group["params"]) > 0:
  3412. optimizer_device = param_group["params"][0].device
  3413. break
  3414. return (model_device, optimizer_device)
  3415. def get_state_dict(self, model, unwrap=True):
  3416. """
  3417. Returns the state dictionary of a model sent through [`Accelerator.prepare`] potentially without full
  3418. precision.
  3419. Args:
  3420. model (`torch.nn.Module`):
  3421. A PyTorch model sent through [`Accelerator.prepare`]
  3422. unwrap (`bool`, *optional*, defaults to `True`):
  3423. Whether to return the original underlying state_dict of `model` or to return the wrapped state_dict
  3424. Returns:
  3425. `dict`: The state dictionary of the model potentially without full precision.
  3426. Example:
  3427. ```python
  3428. >>> import torch
  3429. >>> from accelerate import Accelerator
  3430. >>> accelerator = Accelerator()
  3431. >>> net = torch.nn.Linear(2, 2)
  3432. >>> net = accelerator.prepare(net)
  3433. >>> state_dict = accelerator.get_state_dict(net)
  3434. ```
  3435. """
  3436. if self.distributed_type == DistributedType.DEEPSPEED:
  3437. zero3_sharding = self.deepspeed_config["zero_optimization"]["stage"] == 3
  3438. tp_sharding = self.deepspeed_config.get("tensor_parallel", {}).get("autotp_size", 0) > 1
  3439. if zero3_sharding or tp_sharding:
  3440. if model.zero_gather_16bit_weights_on_model_save():
  3441. ver_min_required = "0.16.4"
  3442. if tp_sharding and not compare_versions("deepspeed", ">=", ver_min_required):
  3443. raise ImportError(
  3444. f"Deepspeed TP requires deepspeed>={ver_min_required}. Please update DeepSpeed via `pip install deepspeed -U`."
  3445. )
  3446. state_dict = (
  3447. model._consolidated_16bit_state_dict()
  3448. if tp_sharding
  3449. else model._zero3_consolidated_16bit_state_dict()
  3450. )
  3451. else:
  3452. raise ValueError(
  3453. "Cannot get 16bit model weights because `stage3_gather_16bit_weights_on_model_save` in DeepSpeed config is False. "
  3454. "To save the model weights in 16bit, set `stage3_gather_16bit_weights_on_model_save` to True in DeepSpeed config file or "
  3455. "set `zero3_save_16bit_model` to True when using `accelerate config`. "
  3456. "To save the full checkpoint, run `model.save_checkpoint(save_dir)` and use `zero_to_fp32.py` to recover weights."
  3457. )
  3458. else:
  3459. from deepspeed.checkpoint.utils import clone_tensors_for_torch_save
  3460. state_dict = clone_tensors_for_torch_save(self.unwrap_model(model).state_dict())
  3461. elif self.is_fsdp2:
  3462. from torch.distributed.checkpoint.state_dict import StateDictOptions, get_model_state_dict
  3463. options = StateDictOptions(full_state_dict=True, broadcast_from_rank0=True, cpu_offload=True)
  3464. state_dict = get_model_state_dict(model, options=options)
  3465. elif self.distributed_type == DistributedType.FSDP:
  3466. from torch.distributed.fsdp import FullStateDictConfig, StateDictType
  3467. from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  3468. full_state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
  3469. with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, full_state_dict_config):
  3470. state_dict = model.state_dict()
  3471. else:
  3472. if unwrap:
  3473. model = self.unwrap_model(model)
  3474. state_dict = model.state_dict()
  3475. return state_dict
  3476. def register_for_checkpointing(self, *objects):
  3477. """
  3478. Makes note of `objects` and will save or load them in during `save_state` or `load_state`.
  3479. These should be utilized when the state is being loaded or saved in the same script. It is not designed to be
  3480. used in different scripts.
  3481. <Tip>
  3482. Every `object` must have a `load_state_dict` and `state_dict` function to be stored.
  3483. </Tip>
  3484. Example:
  3485. ```python
  3486. >>> from accelerate import Accelerator
  3487. >>> accelerator = Accelerator()
  3488. >>> # Assume `CustomObject` has a `state_dict` and `load_state_dict` function.
  3489. >>> obj = CustomObject()
  3490. >>> accelerator.register_for_checkpointing(obj)
  3491. >>> accelerator.save_state("checkpoint.pt")
  3492. ```
  3493. """
  3494. invalid_objects = []
  3495. for obj in objects:
  3496. if not hasattr(obj, "state_dict") or not hasattr(obj, "load_state_dict"):
  3497. invalid_objects.append(obj)
  3498. if len(invalid_objects) > 0:
  3499. err = "All `objects` must include a `state_dict` and `load_state_dict` function to be stored. The following inputs are invalid:"
  3500. for index, obj in enumerate(invalid_objects):
  3501. err += f"\n\t- Item at index {index}, `{get_pretty_name(obj)}`"
  3502. raise ValueError(err)
  3503. self._custom_objects.extend(objects)
  3504. @contextmanager
  3505. def maybe_context_parallel(
  3506. self,
  3507. buffers: list[torch.Tensor] | None = None,
  3508. buffer_seq_dims: list[int] | None = None,
  3509. no_restore_buffers: set[torch.Tensor] | None = None,
  3510. ):
  3511. """
  3512. A context manager that enables context parallel training.
  3513. Args:
  3514. buffers (`list[torch.Tensor]`, `optional`):
  3515. Buffers, which are going to be sharded along the sequence dimension. Common examples are inputs, labels
  3516. or positional embedding buffers. This context manager will modify these buffers in-place, and after
  3517. exiting the context, the buffers will be restored to their original state. To avoid unnecessary
  3518. restores, you can use `no_restore_buffers` to specify which buffers don't need to be restored.
  3519. buffer_seq_dims (`list[int]`, `optional`):
  3520. Sequence dimensions of `buffers`.
  3521. no_restore_buffers (`set[torch.Tensor]`, `optional`):
  3522. This set must be a subset of `buffers`. Specifies which buffers from `buffers` argument won't be
  3523. restored after the context exits. These buffers will be then kept in sharded state.
  3524. <Tip warning={true}>
  3525. `context_parallel` is currently supported with FSDP2 and requires `parallelism_config.cp_size` >
  3526. 1. If either of these conditions are not met, this context manager will have no effect, though to enable fewer
  3527. code changes it will not raise an Exception.
  3528. </Tip>
  3529. <Tip warning={true}>
  3530. This context manager has to be recreated with each training step, as shown in the example below.
  3531. </Tip>
  3532. Example:
  3533. ```python
  3534. >>> for batch in dataloader:
  3535. ... with accelerator.maybe_context_parallel(
  3536. ... buffers=[batch["input_ids"], batch["attention_mask"]],
  3537. ... buffer_seq_dims=[1, 1],
  3538. ... no_restore_buffers={batch["input_ids"]},
  3539. ... ):
  3540. ... outputs = model(batch)
  3541. ... ...
  3542. ```
  3543. """
  3544. # We don't need to check FSDP2 as parallelism_config does that for us
  3545. # Invariant: in this branch self._cp_context is set, as it was set by `self._prepare_cp`
  3546. if (
  3547. self.parallelism_config
  3548. and self.parallelism_config.cp_backend == "torch"
  3549. and self.parallelism_config.cp_enabled
  3550. ):
  3551. with self._cp_context(
  3552. buffers=buffers, buffer_seq_dims=buffer_seq_dims, no_restore_buffers=no_restore_buffers
  3553. ):
  3554. yield
  3555. else:
  3556. logger.warning_once(
  3557. "Context parallel training is not enabled. This context manager will have no effect. "
  3558. "To enable it, set `parallelism_config.cp_size` > 1 in the `Accelerator` constructor."
  3559. )
  3560. yield
  3561. @contextmanager
  3562. def autocast(self, autocast_handler: AutocastKwargs = None):
  3563. """
  3564. Will apply automatic mixed-precision inside the block inside this context manager, if it is enabled. Nothing
  3565. different will happen otherwise.
  3566. A different `autocast_handler` can be passed in to override the one set in the `Accelerator` object. This is
  3567. useful in blocks under `autocast` where you want to revert to fp32.
  3568. Example:
  3569. ```python
  3570. >>> from accelerate import Accelerator
  3571. >>> accelerator = Accelerator(mixed_precision="fp16")
  3572. >>> with accelerator.autocast():
  3573. ... train()
  3574. ```
  3575. """
  3576. if autocast_handler is None:
  3577. autocast_handler = self.autocast_handler
  3578. autocast_context = get_mixed_precision_context_manager(self.native_amp, autocast_handler)
  3579. with autocast_context:
  3580. yield
  3581. @contextmanager
  3582. def profile(self, profile_handler: ProfileKwargs | None = None):
  3583. """
  3584. Will profile the code inside the context manager. The profile will be saved to a Chrome Trace file if
  3585. `profile_handler.output_trace_dir` is set.
  3586. A different `profile_handler` can be passed in to override the one set in the `Accelerator` object.
  3587. Args:
  3588. profile_handler (`ProfileKwargs`, *optional*):
  3589. The profile handler to use for this context manager. If not passed, will use the one set in the
  3590. `Accelerator` object.
  3591. Example:
  3592. ```python
  3593. # Profile with default settings
  3594. from accelerate import Accelerator
  3595. from accelerate.utils import ProfileKwargs
  3596. accelerator = Accelerator()
  3597. with accelerator.profile() as prof:
  3598. train()
  3599. accelerator.print(prof.key_averages().table())
  3600. # Profile with the custom handler
  3601. def custom_handler(prof):
  3602. print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
  3603. kwargs = ProfileKwargs(schedule_option=dict(wait=1, warmup=1, active=1), on_trace_ready=custom_handler)
  3604. accelerator = Accelerator(kwarg_handler=[kwargs])
  3605. with accelerator.profile() as prof:
  3606. for _ in range(10):
  3607. train_iteration()
  3608. prof.step()
  3609. # Profile and export to Chrome Trace
  3610. kwargs = ProfileKwargs(output_trace_dir="output_trace")
  3611. accelerator = Accelerator(kwarg_handler=[kwargs])
  3612. with accelerator.profile():
  3613. train()
  3614. ```
  3615. """
  3616. profile_handler = profile_handler or self.profile_handler or ProfileKwargs()
  3617. with profile_handler.build() as profiler:
  3618. yield profiler
  3619. if profile_handler.output_trace_dir is None:
  3620. return
  3621. os.makedirs(profile_handler.output_trace_dir, exist_ok=True)
  3622. profiler.export_chrome_trace(
  3623. os.path.join(profile_handler.output_trace_dir, PROFILE_PATTERN_NAME.format(suffix=self.process_index))
  3624. )
  3625. self.wait_for_everyone()
  3626. @property
  3627. def optimizer_step_was_skipped(self):
  3628. """
  3629. Whether or not the optimizer update was skipped (because of gradient overflow in mixed precision), in which
  3630. case the learning rate should not be changed.
  3631. """
  3632. for optimizer in self._optimizers:
  3633. if optimizer.step_was_skipped:
  3634. return True
  3635. return False
  3636. def skip_first_batches(self, dataloader, num_batches: int = 0):
  3637. """
  3638. Creates a new `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches`.
  3639. Args:
  3640. dataloader (`torch.utils.data.DataLoader`): The data loader in which to skip batches.
  3641. num_batches (`int`, *optional*, defaults to 0): The number of batches to skip
  3642. Example:
  3643. ```python
  3644. >>> from accelerate import Accelerator
  3645. >>> accelerator = Accelerator()
  3646. >>> dataloader, model, optimizer, scheduler = accelerator.prepare(dataloader, model, optimizer, scheduler)
  3647. >>> skipped_dataloader = accelerator.skip_first_batches(dataloader, num_batches=2)
  3648. >>> # for the first epoch only
  3649. >>> for input, target in skipped_dataloader:
  3650. ... optimizer.zero_grad()
  3651. ... output = model(input)
  3652. ... loss = loss_func(output, target)
  3653. ... accelerator.backward(loss)
  3654. ... optimizer.step()
  3655. >>> # subsequent epochs
  3656. >>> for input, target in dataloader:
  3657. ... optimizer.zero_grad()
  3658. ... ...
  3659. ```
  3660. """
  3661. return skip_first_batches(dataloader, num_batches=num_batches)
  3662. def __deepcopy__(self, memo):
  3663. logger.info("Deep copying the `Accelerator` object, note that this will point to the same original object.")
  3664. return self
  3665. def verify_device_map(self, model: torch.nn.Module) -> bool:
  3666. """
  3667. Verifies that `model` has not been prepared with big model inference with a device-map resembling `auto`.
  3668. """
  3669. # Checks if any of the child modules has the attribute `hf_device_map` and this map has more than one entry.
  3670. for m in model.modules():
  3671. if hasattr(m, "hf_device_map") and len(m.hf_device_map) > 1:
  3672. return True
  3673. return False
  3674. def lomo_backward(self, loss: torch.Tensor, learning_rate: float) -> None:
  3675. """
  3676. Runs backward pass on LOMO optimizers.
  3677. """
  3678. if is_lomo_available():
  3679. # We need to import locally to avoid circular imports since lomo imports stuff from
  3680. # transformers & accelerate
  3681. from lomo_optim import AdaLomo, Lomo
  3682. if learning_rate is None:
  3683. raise ValueError("A learning rate must be passed in order to call backward pass with LOMO optimizers.")
  3684. _backward_called = False
  3685. for optimizer in self._optimizers:
  3686. if isinstance(optimizer.optimizer, (Lomo, AdaLomo)):
  3687. optimizer.optimizer.fused_backward(loss, learning_rate)
  3688. _backward_called = True
  3689. if not _backward_called:
  3690. raise ValueError(
  3691. "Backward pass not properly called on LOMO optimizers. Are you sure you passed a LOMO optimizer in accelerator.prepare()?"
  3692. )
    @property
    def fp8_backend(self) -> FP8BackendType:
        """
        Returns the configured backend for training in FP8.

        Resolution order:
        1. If `self.has_fp8_handler` is truthy, the first non-`None` handler among `fp8_recipe_handler` (its `backend`
           field converted to `FP8BackendType`), `ao_recipe_handler` (AO), `te_recipe_handler` (TE) and
           `msamp_recipe_handler` (MSAMP) decides.
        2. Otherwise, if a DeepSpeed plugin is configured with `enable_msamp` set, MSAMP is returned.
        3. If neither step returned, the value comes from the `ACCELERATE_FP8_BACKEND` environment variable, read via
           `parse_choice_from_env` with `"NO"` passed as the default.
        """
        if self.has_fp8_handler:
            # The generic FP8 recipe handler takes precedence and names its backend explicitly.
            if self.fp8_recipe_handler is not None:
                return FP8BackendType(self.fp8_recipe_handler.backend)
            elif self.ao_recipe_handler is not None:
                return FP8BackendType.AO
            elif self.te_recipe_handler is not None:
                return FP8BackendType.TE
            elif self.msamp_recipe_handler is not None:
                return FP8BackendType.MSAMP
        # No FP8 handler registered: MS-AMP may still be requested through the DeepSpeed plugin.
        elif self.state.deepspeed_plugin is not None and self.state.deepspeed_plugin.enable_msamp:
            return FP8BackendType.MSAMP
        return FP8BackendType(parse_choice_from_env("ACCELERATE_FP8_BACKEND", "NO"))