dataclasses.py 137 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817281828192820282128222823282428252826282728282829283028312832283328342835283628372838283928402841284228432844284528462847284828492850285128522853285428552856285728582859286028612862286328642865286628672868286928702871287228732874287528762877287828792880288128822883288428852886288728882889289028912892289328942895289628972898289929002901290229032904290529062907290829092910291129122913291429152916291729182919292029212922292329242925292629272928292929302931293229332934293529362937293829392940294129422943294429452946294729482949295029512952295329542955295629572958295929602961296229632964296529662967296829692970297129722973297429752976297729782979298029812982298329842985
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. General namespace and dataclass related classes
  16. """
  17. import argparse
  18. import copy
  19. import enum
  20. import functools
  21. import logging
  22. import os
  23. import warnings
  24. from collections.abc import Iterable
  25. from contextlib import contextmanager
  26. from dataclasses import dataclass, field
  27. from datetime import timedelta
  28. from typing import TYPE_CHECKING, Any, Callable, Literal, Optional, Union, get_args
  29. import torch
  30. from .constants import (
  31. BETA_CP_AVAILABLE_PYTORCH_VERSION,
  32. BETA_TP_AVAILABLE_PYTORCH_VERSION,
  33. BETA_TP_AVAILABLE_TRANSFORMERS_VERSION,
  34. FSDP2_PYTORCH_VERSION,
  35. FSDP_AUTO_WRAP_POLICY,
  36. FSDP_BACKWARD_PREFETCH,
  37. FSDP_SHARDING_STRATEGY,
  38. MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
  39. XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
  40. )
  41. from .environment import parse_flag_from_env, str_to_bool
  42. from .imports import (
  43. is_cuda_available,
  44. is_hpu_available,
  45. is_mlu_available,
  46. is_msamp_available,
  47. is_musa_available,
  48. is_npu_available,
  49. is_transformer_engine_available,
  50. is_xpu_available,
  51. )
  52. from .versions import compare_versions, is_torch_version
  53. if TYPE_CHECKING:
  54. # Mock imports for type checking
  55. from torchao.float8 import Float8LinearConfig
  # Module-level logger, named after this module.
  logger = logging.getLogger(__name__)
  class KwargsHandler:
      """
      Internal mixin giving a dataclass a `to_dict()` snapshot and a `to_kwargs()` method that reports only the
      attributes whose values differ from the class defaults.
      """

      def to_dict(self):
          """
          Returns a deep copy of the instance attributes, so mutating the result never affects the instance.
          """
          return copy.deepcopy(vars(self))

      def to_kwargs(self):
          """
          Returns a dictionary containing the attributes with values different from the default of this class.
          """
          # import clear_environment here to avoid circular import problem
          from .environment import clear_environment

          # The reference instance is built inside `clear_environment()`; presumably this keeps environment
          # variables from leaking into the defaults -- confirm against `.environment`.
          with clear_environment():
              defaults = self.__class__().to_dict()

          current = self.to_dict()
          changed = {}
          for name, value in current.items():
              if defaults[name] != value:
                  changed[name] = value
          return changed
  class EnumWithContains(enum.EnumMeta):
      "A metaclass that lets the `in` operator test whether an item can be resolved to a member of the enum"

      def __contains__(cls, item):
          # A by-value lookup raises `ValueError` when nothing matches; that is the only failure treated as
          # "not contained".
          try:
              cls(item)
          except ValueError:
              return False
          else:
              return True
  class BaseEnum(enum.Enum, metaclass=EnumWithContains):
      "An enum class whose members render as their value with `str(Enum.key)`"

      def __str__(self):
          return self.value

      @classmethod
      def list(cls):
          "Method to list the string form of every item in `cls`, in definition order"
          return [str(member) for member in cls]
  @dataclass
  class AutocastKwargs(KwargsHandler):
      """
      Pass this object to your [`Accelerator`] to customize how `torch.autocast` behaves. See the documentation of
      this [context manager](https://pytorch.org/docs/stable/amp.html#torch.autocast) for details on each argument.

      Example:

      ```python
      from accelerate import Accelerator
      from accelerate.utils import AutocastKwargs

      kwargs = AutocastKwargs(cache_enabled=True)
      accelerator = Accelerator(kwargs_handlers=[kwargs])
      ```
      """

      # Field names follow the `torch.autocast` arguments described in the linked documentation.
      enabled: bool = True
      # Left as `None` unless the caller sets it explicitly.
      cache_enabled: Optional[bool] = None
  class DDPCommunicationHookType(BaseEnum):
      """
      Enumerates the communication hooks that can be used with DDP.

      Values:

          - **NO** -- no communication hook
          - **FP16** -- DDP communication hook to compress the gradients in FP16
          - **BF16** -- DDP communication hook to compress the gradients in BF16
          - **POWER_SGD** -- DDP communication hook to use PowerSGD
          - **BATCHED_POWER_SGD** -- DDP communication hook to use batched PowerSGD
      """

      NO = "no"
      FP16 = "fp16"
      BF16 = "bf16"
      POWER_SGD = "power_sgd"
      BATCHED_POWER_SGD = "batched_power_sgd"
  @dataclass
  class DistributedDataParallelKwargs(KwargsHandler):
      """
      Pass this object to your [`Accelerator`] to customize how your model is wrapped in a
      `torch.nn.parallel.DistributedDataParallel`. See the documentation of this
      [wrapper](https://pytorch.org/docs/stable/generated/torch.nn.parallel.DistributedDataParallel.html) for details
      on each argument.

      <Tip warning={true}>

      `gradient_as_bucket_view` is only available in PyTorch 1.7.0 and later versions.

      `static_graph` is only available in PyTorch 1.11.0 and later versions.

      </Tip>

      Example:

      ```python
      from accelerate import Accelerator
      from accelerate.utils import DistributedDataParallelKwargs

      kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
      accelerator = Accelerator(kwargs_handlers=[kwargs])
      ```
      """

      dim: int = 0
      broadcast_buffers: bool = True
      bucket_cap_mb: int = 25
      find_unused_parameters: bool = False
      check_reduction: bool = False
      gradient_as_bucket_view: bool = False
      static_graph: bool = False

      # The three `comm_*` fields configure `register_comm_hook()` and are left out of `to_dict()` by default.
      comm_hook: DDPCommunicationHookType = DDPCommunicationHookType.NO
      comm_wrapper: Literal[
          DDPCommunicationHookType.NO,
          DDPCommunicationHookType.FP16,
          DDPCommunicationHookType.BF16,
      ] = DDPCommunicationHookType.NO
      comm_state_option: dict = field(default_factory=dict)

      def to_dict(self, ignore_keys=("comm_hook", "comm_wrapper", "comm_state_option")):
          """
          Returns a copy of the attributes, leaving out every key found in `ignore_keys`.
          """
          kept = {}
          for key, value in super().to_dict().items():
              if key in ignore_keys:
                  continue
              kept[key] = value
          return kept

      def register_comm_hook(self, model):
          """
          Registers the configured communication hook on `model` via `model.register_comm_hook`.

          Nothing is registered when `comm_hook` does not map to a hook (e.g. `DDPCommunicationHookType.NO`). When
          `comm_wrapper` maps to a compression wrapper, the hook is wrapped with it first. A `PowerSGDState` built
          from `comm_state_option` is passed as the state for the two PowerSGD hooks; every other hook gets `None`.
          """
          from torch.distributed.algorithms.ddp_comm_hooks import (
              default_hooks,
              powerSGD_hook,
          )

          available_hooks: dict[DDPCommunicationHookType, Callable] = {
              DDPCommunicationHookType.FP16: default_hooks.fp16_compress_hook,
              DDPCommunicationHookType.BF16: default_hooks.bf16_compress_hook,
              DDPCommunicationHookType.POWER_SGD: powerSGD_hook.powerSGD_hook,
              DDPCommunicationHookType.BATCHED_POWER_SGD: powerSGD_hook.batched_powerSGD_hook,
          }
          available_wrappers: dict[DDPCommunicationHookType, Callable] = {
              DDPCommunicationHookType.FP16: default_hooks.fp16_compress_wrapper,
              DDPCommunicationHookType.BF16: default_hooks.bf16_compress_wrapper,
          }

          selected_hook: Optional[Callable] = available_hooks.get(self.comm_hook)
          selected_wrapper: Optional[Callable] = available_wrappers.get(self.comm_wrapper)

          if selected_hook and selected_wrapper:
              selected_hook = selected_wrapper(selected_hook)

          if not selected_hook:
              return

          # Only the PowerSGD hooks carry state; the process group argument is left as `None`.
          power_sgd_kinds = (
              DDPCommunicationHookType.POWER_SGD,
              DDPCommunicationHookType.BATCHED_POWER_SGD,
          )
          if self.comm_hook in power_sgd_kinds:
              hook_state = powerSGD_hook.PowerSGDState(None, **self.comm_state_option)
          else:
              hook_state = None

          model.register_comm_hook(
              state=hook_state,
              hook=selected_hook,
          )
  @dataclass
  class GradScalerKwargs(KwargsHandler):
      """
      Pass this object to your [`Accelerator`] to customize the behavior of mixed precision, specifically how the
      `torch.amp.GradScaler` or `torch.cuda.amp.GradScaler` used is created. See the documentation of this
      [scaler](https://pytorch.org/docs/stable/amp.html?highlight=gradscaler) for details on each argument.

      <Tip warning={true}>

      `torch.cuda.amp.GradScaler` is only available in PyTorch 1.5.0 and later versions, and `torch.amp.GradScaler` is
      only available in PyTorch 2.4.0 and later versions.

      </Tip>

      Example:

      ```python
      from accelerate import Accelerator
      from accelerate.utils import GradScalerKwargs

      kwargs = GradScalerKwargs(backoff_factor=0.25)
      accelerator = Accelerator(kwargs_handlers=[kwargs])
      ```
      """

      # Field names follow the scaler arguments described in the linked documentation.
      init_scale: float = 65536.0
      growth_factor: float = 2.0
      backoff_factor: float = 0.5
      growth_interval: int = 2000
      enabled: bool = True
  @dataclass
  class InitProcessGroupKwargs(KwargsHandler):
      """
      Pass this object to your [`Accelerator`] to customize the initialization of the distributed processes. See the
      documentation of this
      [method](https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group) for details on
      each argument.

      Note: If `timeout` is set to `None`, the default will be based upon how `backend` is set: 600 seconds for
      `"nccl"`, 1800 seconds for anything else.

      ```python
      from datetime import timedelta
      from accelerate import Accelerator
      from accelerate.utils import InitProcessGroupKwargs

      kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=800))
      accelerator = Accelerator(kwargs_handlers=[kwargs])
      ```
      """

      backend: Optional[str] = "nccl"
      init_method: Optional[str] = None
      timeout: Optional[timedelta] = None

      def __post_init__(self):
          # An explicit timeout always wins.
          if self.timeout is not None:
              return
          # Otherwise pick the backend-dependent default.
          default_seconds = 600 if self.backend == "nccl" else 1800
          self.timeout = timedelta(seconds=default_seconds)
  # Literal aliases enumerating the accepted string values of the FP8 recipe options below.
  # They are also read at runtime through `typing.get_args` to validate user input.
  Backend = Literal["MSAMP", "TE"]
  OptLevel = Literal["O1", "O2"]
  FP8Format = Literal["HYBRID", "E4M3", "E5M2"]
  AmaxComputeAlgorithm = Literal["max", "most_recent"]


  # FP8 training recipe kwargs
  @dataclass
  class AORecipeKwargs(KwargsHandler):
      """
      Pass this object to your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
      training with `torchao` FP8.

      Args:
          config (`torchao.float8.Float8LinearConfig`, *optional*, default to `None`):
              The configuration for the FP8 training. In general, the default config should be sufficient.
          module_filter_func (`Callable`, *optional*, default to `None`):
              Optional function that must take in a module and layer name, and returns a boolean indicating whether the
              module should be converted to FP8. Defaults to `accelerate.utils.ao.filter_linear_layers`. See it for an
              example.
      """

      # Forward reference: `Float8LinearConfig` is only imported under `TYPE_CHECKING`.
      config: Optional["Float8LinearConfig"] = None
      module_filter_func: Optional[Callable] = None
  @dataclass
  class TERecipeKwargs(KwargsHandler):
      """
      Pass this object to your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
      training with `transformer-engine`.

      Every argument left as `None` is resolved in `__post_init__` from the matching `ACCELERATE_FP8_*` environment
      variable, falling back to the default listed below.

      <Tip>

      For more information on the args, please refer to the API
      [documentation](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/common.html).

      </Tip>

      ```python
      from accelerate import Accelerator
      from accelerate.utils import TERecipeKwargs

      kwargs = TERecipeKwargs(fp8_format="HYBRID")
      accelerator = Accelerator(mixed_precision="fp8", kwargs_handlers=[kwargs])
      ```

      Args:
          use_autocast_during_eval (`bool`, *optional*, default to `False`):
              Whether to use FP8 autocast during eval mode. Generally better metrics are found when this is `False`.
          margin (`int`, *optional*, default to 0):
              The margin to use for the gradient scaling.
          interval (`int`, *optional*, default to 1):
              The interval to use for how often the scaling factor is recomputed.
          fp8_format (`str`, *optional*, default to "HYBRID"):
              The format to use for the FP8 recipe. Must be one of `HYBRID`, `E4M3` or `E5M2`. (Generally `HYBRID` for
              training, `E4M3` or `E5M2` for evaluation)
          amax_history_len (`int`, *optional*, default to 1024):
              The length of the history to use for the scaling factor computation
          amax_compute_algo (`str`, *optional*, default to "most_recent"):
              The algorithm to use for the scaling factor computation. Must be one of `max` or `most_recent`.
          override_linear_precision (`tuple` of three `bool`, *optional*, default to `(False, False, False)`):
              Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision.
          use_mxfp8_block_scaling (`bool`, *optional*):
              Flag read from `ACCELERATE_FP8_USE_MXFP8_BLOCK_SCALING` when left as `None`. Judging by its name it
              selects MXFP8 block scaling; confirm against the TransformerEngine integration.

      Raises:
          ImportError: If TransformerEngine is not available.
          ValueError: If `fp8_format` or `amax_compute_algo` is not one of the accepted values.
      """

      use_autocast_during_eval: Optional[bool] = None
      margin: Optional[int] = None
      interval: Optional[int] = None
      fp8_format: FP8Format = None
      amax_history_len: Optional[int] = None
      amax_compute_algo: AmaxComputeAlgorithm = None
      override_linear_precision: tuple[bool, bool, bool] = None
      use_mxfp8_block_scaling: Optional[bool] = None

      def __post_init__(self):
          env_prefix = "ACCELERATE_FP8_"
          if not is_transformer_engine_available():
              raise ImportError("TransformerEngine is not available. Please install it or use a different backend.")

          def from_env(suffix, fallback):
              # Raw environment lookup for one `ACCELERATE_FP8_*` setting.
              return os.environ.get(env_prefix + suffix, fallback)

          if self.use_autocast_during_eval is None:
              self.use_autocast_during_eval = parse_flag_from_env(env_prefix + "USE_AUTOCAST_DURING_EVAL")
          if self.margin is None:
              self.margin = int(from_env("MARGIN", 0))
          if self.interval is None:
              self.interval = int(from_env("INTERVAL", 1))

          # The format is normalized to upper case before it is validated.
          if self.fp8_format is None:
              self.fp8_format = from_env("FORMAT", "HYBRID")
          self.fp8_format = self.fp8_format.upper()
          valid_formats = get_args(FP8Format)
          if self.fp8_format not in valid_formats:
              raise ValueError(f"`fp8_format` must be one of {' or '.join(valid_formats)}.")

          # The amax algorithm is normalized to lower case before it is validated.
          if self.amax_compute_algo is None:
              self.amax_compute_algo = from_env("AMAX_COMPUTE_ALGO", "most_recent")
          self.amax_compute_algo = self.amax_compute_algo.lower()
          valid_algos = get_args(AmaxComputeAlgorithm)
          if self.amax_compute_algo not in valid_algos:
              raise ValueError(f"`amax_compute_algo` must be one of {' or '.join(valid_algos)}")

          if self.amax_history_len is None:
              self.amax_history_len = int(from_env("AMAX_HISTORY_LEN", 1024))

          if self.override_linear_precision is None:
              # Order matters: (fprop, dgrad, wgrad).
              self.override_linear_precision = tuple(
                  parse_flag_from_env(env_prefix + flag_name)
                  for flag_name in ("OVERRIDE_FPROP", "OVERRIDE_DGRAD", "OVERRIDE_WGRAD")
              )

          if self.use_mxfp8_block_scaling is None:
              self.use_mxfp8_block_scaling = parse_flag_from_env(env_prefix + "USE_MXFP8_BLOCK_SCALING")
  @dataclass
  class MSAMPRecipeKwargs(KwargsHandler):
      """
      Pass this object to your [`Accelerator`] to customize the initialization of the recipe for FP8 mixed precision
      training with `ms-amp`.

      Args:
          opt_level (`str`, *optional*, default to "O2"):
              Must be one of `O1` or `O2`. When left as `None`, it is read from the `ACCELERATE_FP8_OPT_LEVEL`
              environment variable.

      Raises:
          ValueError: If `opt_level` is not one of the accepted values.
      """

      opt_level: OptLevel = None

      def __post_init__(self):
          env_prefix = "ACCELERATE_FP8_"
          if self.opt_level is None:
              self.opt_level = os.environ.get(env_prefix + "OPT_LEVEL", "O2")
          allowed_levels = get_args(OptLevel)
          if self.opt_level not in allowed_levels:
              raise ValueError(f"`opt_level` must be one of {' or '.join(allowed_levels)}")
  @dataclass
  class FP8RecipeKwargs(TERecipeKwargs, MSAMPRecipeKwargs):
      """
      Deprecated. Please use one of the proper FP8 recipe kwargs classes such as `TERecipeKwargs` or `MSAMPRecipeKwargs`
      instead.

      Args:
          backend (`str`, *optional*):
              Either `"MSAMP"` or `"TE"` (case-insensitive). When left as `None`, it is read from the
              `ACCELERATE_FP8_BACKEND` environment variable, falling back to `"msamp"` when MS-AMP is available and
              `"te"` otherwise.

      Raises:
          ValueError: If `backend` is not one of the accepted values, or if the option validated by the selected
              backend's initializer is invalid.
          ImportError: If the `"TE"` backend is selected and TransformerEngine is not available.
      """

      backend: Backend = None

      def __post_init__(self):
          env_prefix = "ACCELERATE_FP8_"
          warnings.warn(
              "FP8RecipeKwargs is deprecated and will be removed in Accelerate v2.0.0. "
              "Please use one of the proper FP8 recipe kwargs classes such as TERecipeKwargs or MSAMPRecipeKwargs instead.",
              FutureWarning,
          )
          default_backend = "msamp" if is_msamp_available() else "te"
          if self.backend is None:
              self.backend = os.environ.get(env_prefix + "BACKEND", default_backend)
          self.backend = self.backend.upper()
          if self.backend not in get_args(Backend):
              raise ValueError("`backend` must be 'MSAMP' or 'TE' (TransformerEngine) to use `FP8RecipeKwargs`.")
          # Dispatch explicitly instead of calling `super().__post_init__()`: `super()` always resolves to
          # `TERecipeKwargs.__post_init__` (first in the MRO, and it does not chain further). The MSAMP backend
          # would then fail with an ImportError whenever TransformerEngine is missing, and `opt_level` would never
          # be resolved or validated.
          if self.backend == "TE":
              TERecipeKwargs.__post_init__(self)
          else:
              MSAMPRecipeKwargs.__post_init__(self)
  # Literal alias listing the activity names accepted by `ProfileKwargs.activities`.
  ProfilerActivity = Literal["cpu", "xpu", "mtia", "cuda", "hpu"]
  @dataclass
  class ProfileKwargs(KwargsHandler):
      """
      Pass this object to your [`Accelerator`] to customize the initialization of the profiler. See the documentation
      of this [context manager](https://pytorch.org/docs/stable/profiler.html#torch.profiler.profile) for details on
      each argument.

      <Tip warning={true}>

      `torch.profiler` is only available in PyTorch 1.8.1 and later versions.

      </Tip>

      Example:

      ```python
      from accelerate import Accelerator
      from accelerate.utils import ProfileKwargs

      kwargs = ProfileKwargs(activities=["cpu", "cuda"])
      accelerator = Accelerator(kwargs_handlers=[kwargs])
      ```

      Args:
          activities (`List[str]`, *optional*, default to `None`):
              The list of activity groups to use in profiling. Must be one of `"cpu"`, `"xpu"`, `"mtia"`, "hpu" or
              `"cuda"`.
          schedule_option (`Dict[str, int]`, *optional*, default to `None`):
              The schedule option to use for the profiler. Available keys are `wait`, `warmup`, `active`, `repeat` and
              `skip_first`. The profiler will skip the first `skip_first` steps, then wait for `wait` steps, then do the
              warmup for the next `warmup` steps, then do the active recording for the next `active` steps and then
              repeat the cycle starting with `wait` steps. The optional number of cycles is specified with the `repeat`
              parameter, the zero value means that the cycles will continue until the profiling is finished.
          on_trace_ready (`Callable`, *optional*, default to `None`):
              Callable that is called at each step when schedule returns `ProfilerAction.RECORD_AND_SAVE` during the
              profiling.
          record_shapes (`bool`, *optional*, default to `False`):
              Save information about operator’s input shapes.
          profile_memory (`bool`, *optional*, default to `False`):
              Track tensor memory allocation/deallocation
          with_stack (`bool`, *optional*, default to `False`):
              Record source information (file and line number) for the ops.
          with_flops (`bool`, *optional*, default to `False`):
              Use formula to estimate the FLOPS of specific operators
          with_modules (`bool`, *optional*, default to `False`):
              Record module hierarchy (including function names) corresponding to the callstack of the op.
          output_trace_dir (`str`, *optional*, default to `None`):
              Exports the collected trace in Chrome JSON format. Chrome use 'chrome://tracing' view json file. Defaults
              to None, which means profiling does not store json files.
      """

      activities: Optional[list[ProfilerActivity]] = None
      schedule_option: Optional[dict[str, int]] = None
      on_trace_ready: Optional[Callable] = None
      record_shapes: bool = False
      profile_memory: bool = False
      with_stack: bool = False
      with_flops: bool = False
      with_modules: bool = False
      output_trace_dir: Optional[str] = None

      def _get_profiler_activity(self, activity: ProfilerActivity) -> torch.profiler.ProfilerActivity:
          """Translate an activity name into the matching `torch.profiler.ProfilerActivity`.

          Args:
              activity (str): The profiler activity name.

          Returns:
              torch.profiler.ProfilerActivity: The profiler activity.

          Raises:
              ValueError: If `activity` is not supported in the current environment.
          """
          # "cpu" and "cuda" are always offered; the remaining entries depend on the runtime environment.
          supported: dict[str, torch.profiler.ProfilerActivity] = {
              "cpu": torch.profiler.ProfilerActivity.CPU,
              "cuda": torch.profiler.ProfilerActivity.CUDA,
          }

          if is_hpu_available():
              supported["hpu"] = torch.profiler.ProfilerActivity.HPU

          # The version check comes first, so `torch.xpu` / `torch.mtia` are only touched on recent enough PyTorch.
          if is_torch_version(">=", XPU_PROFILING_AVAILABLE_PYTORCH_VERSION) and torch.xpu.is_available():
              supported["xpu"] = torch.profiler.ProfilerActivity.XPU

          if is_torch_version(">=", MITA_PROFILING_AVAILABLE_PYTORCH_VERSION) and torch.mtia.is_available():
              supported["mtia"] = torch.profiler.ProfilerActivity.MTIA

          if activity in supported:
              return supported[activity]
          raise ValueError(f"Invalid profiler activity: {activity}. Must be one of {list(supported)}.")

      def build(self) -> torch.profiler.profile:
          """
          Create a `torch.profiler.profile` object from the current configuration.

          `output_trace_dir` is not passed to the profiler here.

          Returns:
              torch.profiler.profile: The profiler object.
          """
          # `None` is forwarded unchanged for both optional settings.
          resolved_activities = (
              None
              if self.activities is None
              else [self._get_profiler_activity(name) for name in self.activities]
          )
          resolved_schedule = (
              None if self.schedule_option is None else torch.profiler.schedule(**self.schedule_option)
          )

          return torch.profiler.profile(
              activities=resolved_activities,
              schedule=resolved_schedule,
              on_trace_ready=self.on_trace_ready,
              record_shapes=self.record_shapes,
              profile_memory=self.profile_memory,
              with_stack=self.with_stack,
              with_flops=self.with_flops,
              with_modules=self.with_modules,
          )
class DistributedType(str, enum.Enum):
    """
    Represents a type of distributed environment.

    Values:

        - **NO** -- Not a distributed environment, just a single process.
        - **MULTI_CPU** -- Distributed on multiple CPU nodes.
        - **MULTI_GPU** -- Distributed on multiple GPUs.
        - **MULTI_MLU** -- Distributed on multiple MLUs.
        - **MULTI_SDAA** -- Distributed on multiple SDAAs.
        - **MULTI_MUSA** -- Distributed on multiple MUSAs.
        - **MULTI_NPU** -- Distributed on multiple NPUs.
        - **MULTI_XPU** -- Distributed on multiple XPUs.
        - **MULTI_HPU** -- Distributed on multiple HPUs.
        - **DEEPSPEED** -- Using DeepSpeed.
        - **FSDP** -- Using Fully Sharded Data Parallel (FSDP).
        - **XLA** -- Using TorchXLA.
        - **MEGATRON_LM** -- Using Megatron-LM.
    """

    # Subclassing str as well as Enum allows the `DistributedType` to be JSON-serializable out of the box.
    # Each member's value is identical to its name.
    NO = "NO"
    MULTI_CPU = "MULTI_CPU"
    MULTI_GPU = "MULTI_GPU"
    MULTI_NPU = "MULTI_NPU"
    MULTI_MLU = "MULTI_MLU"
    MULTI_SDAA = "MULTI_SDAA"
    MULTI_MUSA = "MULTI_MUSA"
    MULTI_XPU = "MULTI_XPU"
    DEEPSPEED = "DEEPSPEED"
    FSDP = "FSDP"
    XLA = "XLA"
    MEGATRON_LM = "MEGATRON_LM"
    MULTI_HPU = "MULTI_HPU"
class SageMakerDistributedType(str, enum.Enum):
    """
    Represents a type of distributed environment on Amazon SageMaker.

    Values:

        - **NO** -- Not a distributed environment, just a single process.
        - **DATA_PARALLEL** -- using sagemaker distributed data parallelism.
        - **MODEL_PARALLEL** -- using sagemaker distributed model parallelism.
    """

    # Subclassing str as well as Enum allows the `SageMakerDistributedType` to be JSON-serializable out of the box.
    NO = "NO"
    DATA_PARALLEL = "DATA_PARALLEL"
    MODEL_PARALLEL = "MODEL_PARALLEL"
class FP8BackendType(str, enum.Enum):
    """
    Represents the backend used for FP8.

    Values:

        - **NO** -- no FP8 backend.
        - **TE** -- using TransformerEngine.
        - **MSAMP** -- using msamp.
        - **AO** -- presumably using torchao; NOTE(review): confirm against the FP8 integration code.
    """

    # Subclassing str as well as Enum allows the `FP8BackendType` to be JSON-serializable out of the box.
    NO = "NO"
    TE = "TE"
    MSAMP = "MSAMP"
    AO = "AO"
class ComputeEnvironment(str, enum.Enum):
    """
    Represents a type of the compute environment.

    Values:

        - **LOCAL_MACHINE** -- private/custom cluster hardware.
        - **AMAZON_SAGEMAKER** -- Amazon SageMaker as compute environment.
    """

    # Subclassing str as well as Enum allows the `ComputeEnvironment` to be JSON-serializable out of the box.
    # Each member's value is identical to its name.
    LOCAL_MACHINE = "LOCAL_MACHINE"
    AMAZON_SAGEMAKER = "AMAZON_SAGEMAKER"
class DynamoBackend(str, BaseEnum):
    """
    Represents a dynamo backend (see https://pytorch.org/docs/stable/torch.compiler.html).

    Values:

        - **NO** -- Do not use torch dynamo.
        - **EAGER** -- Uses PyTorch to run the extracted GraphModule. This is quite useful in debugging TorchDynamo
          issues.
        - **AOT_EAGER** -- Uses AotAutograd with no compiler, i.e, just using PyTorch eager for the AotAutograd's
          extracted forward and backward graphs. This is useful for debugging, and unlikely to give speedups.
        - **INDUCTOR** -- Uses TorchInductor backend with AotAutograd and cudagraphs by leveraging codegened Triton
          kernels. [Read
          more](https://dev-discuss.pytorch.org/t/torchinductor-a-pytorch-native-compiler-with-define-by-run-ir-and-symbolic-shapes/747)
        - **AOT_TS_NVFUSER** -- nvFuser with AotAutograd/TorchScript. [Read
          more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
        - **NVPRIMS_NVFUSER** -- nvFuser with PrimTorch. [Read
          more](https://dev-discuss.pytorch.org/t/tracing-with-primitives-update-1-nvfuser-and-its-primitives/593)
        - **CUDAGRAPHS** -- cudagraphs with AotAutograd. [Read more](https://github.com/pytorch/torchdynamo/pull/757)
        - **OFI** -- Uses Torchscript optimize_for_inference. Inference only. [Read
          more](https://pytorch.org/docs/stable/generated/torch.jit.optimize_for_inference.html)
        - **FX2TRT** -- Uses Nvidia TensorRT for inference optimizations. Inference only. [Read
          more](https://github.com/pytorch/TensorRT/blob/master/docsrc/tutorials/getting_started_with_fx_path.rst)
        - **ONNXRT** -- Uses ONNXRT for inference on CPU/GPU. Inference only. [Read more](https://onnxruntime.ai/)
        - **TENSORRT** -- Uses ONNXRT to run TensorRT for inference optimizations. [Read
          more](https://github.com/onnx/onnx-tensorrt)
        - **AOT_TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for training. [Read
          more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)
        - **TORCHXLA_TRACE_ONCE** -- Uses Pytorch/XLA with TorchDynamo optimization, for inference. [Read
          more](https://github.com/pytorch/xla/blob/r2.0/docs/dynamo.md)
        - **IPEX** -- Uses IPEX for inference on CPU. Inference only. [Read
          more](https://github.com/intel/intel-extension-for-pytorch).
        - **TVM** -- Uses Apache TVM for inference optimizations. [Read more](https://tvm.apache.org/)
        - **HPU_BACKEND** -- Uses HPU backend for inference optimizations.
    """

    # Subclassing str as well as Enum allows the `DynamoBackend` to be JSON-serializable out of the box.
    # Each member's value is its own name in upper case.
    NO = "NO"
    EAGER = "EAGER"
    AOT_EAGER = "AOT_EAGER"
    INDUCTOR = "INDUCTOR"
    AOT_TS_NVFUSER = "AOT_TS_NVFUSER"
    NVPRIMS_NVFUSER = "NVPRIMS_NVFUSER"
    CUDAGRAPHS = "CUDAGRAPHS"
    OFI = "OFI"
    FX2TRT = "FX2TRT"
    ONNXRT = "ONNXRT"
    TENSORRT = "TENSORRT"
    AOT_TORCHXLA_TRACE_ONCE = "AOT_TORCHXLA_TRACE_ONCE"
    TORCHXLA_TRACE_ONCE = "TORCHXLA_TRACE_ONCE"
    IPEX = "IPEX"
    TVM = "TVM"
    HPU_BACKEND = "HPU_BACKEND"
class LoggerType(BaseEnum):
    """Represents a type of supported experiment tracker

    Values:

        - **ALL** -- all available trackers in the environment that are supported
        - **AIM** -- aim as an experiment tracker
        - **TENSORBOARD** -- TensorBoard as an experiment tracker
        - **WANDB** -- wandb as an experiment tracker
        - **TRACKIO** -- trackio as an experiment tracker
        - **COMETML** -- comet_ml as an experiment tracker
        - **MLFLOW** -- mlflow as an experiment tracker
        - **CLEARML** -- clearml as an experiment tracker
        - **DVCLIVE** -- dvclive as an experiment tracker
        - **SWANLAB** -- swanlab as an experiment tracker
    """

    # Values are the lowercase tracker names; note that COMETML maps to "comet_ml" (with an underscore).
    ALL = "all"
    AIM = "aim"
    TENSORBOARD = "tensorboard"
    WANDB = "wandb"
    TRACKIO = "trackio"
    COMETML = "comet_ml"
    MLFLOW = "mlflow"
    CLEARML = "clearml"
    DVCLIVE = "dvclive"
    SWANLAB = "swanlab"
class PrecisionType(str, BaseEnum):
    """Represents a type of precision used on floating point values

    Values:

        - **NO** -- using full precision (FP32)
        - **FP8** -- using 8-bit floating point precision
        - **FP16** -- using half precision
        - **BF16** -- using brain floating point precision
    """

    # Values are the lowercase spellings of the member names.
    NO = "no"
    FP8 = "fp8"
    FP16 = "fp16"
    BF16 = "bf16"
class RNGType(BaseEnum):
    """Enumerates the kinds of random number generator the library refers to by name.

    Each member's value is the lowercase form of its name.
    """

    TORCH = "torch"
    CUDA = "cuda"
    MLU = "mlu"
    SDAA = "sdaa"
    MUSA = "musa"
    NPU = "npu"
    XLA = "xla"
    XPU = "xpu"
    HPU = "hpu"
    GENERATOR = "generator"
class CustomDtype(enum.Enum):
    r"""
    An enum that contains multiple custom dtypes that can be used for `infer_auto_device_map`.

    Values:

        - **FP8** -- identified by the string `"fp8"`
        - **INT4** -- identified by the string `"int4"`
        - **INT2** -- identified by the string `"int2"`
    """

    FP8 = "fp8"
    INT4 = "int4"
    INT2 = "int2"
  622. # data classes
@dataclass
class TensorInformation:
    """
    Record that describes a tensor by its shape and dtype.

    Args:
        shape (`torch.Size`):
            The shape of the tensor.
        dtype (`torch.dtype`):
            The data type of the tensor.
    """

    shape: torch.Size
    dtype: torch.dtype
  627. @dataclass
  628. class DataLoaderConfiguration:
  629. """
  630. Configuration for dataloader-related items when calling `accelerator.prepare`.
  631. Args:
  632. split_batches (`bool`, defaults to `False`):
  633. Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If
  634. `True`, the actual batch size used will be the same on any kind of distributed processes, but it must be a
  635. round multiple of `num_processes` you are using. If `False`, actual batch size used will be the one set in
  636. your script multiplied by the number of processes.
  637. dispatch_batches (`bool`, defaults to `None`):
  638. If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process
  639. and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose
  640. underlying dataset is an `IterableDataset`, `False` otherwise.
  641. even_batches (`bool`, defaults to `True`):
  642. If set to `True`, in cases where the total batch size across all processes does not exactly divide the
  643. dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among
  644. all workers.
  645. use_seedable_sampler (`bool`, defaults to `False`):
  646. Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`]). Ensures
  647. training results are fully reproducible using a different sampling technique. While seed-to-seed results
  648. may differ, on average the differences are negligible when using multiple different seeds to compare.
  649. Should also be ran with [`~utils.set_seed`] for the best results.
  650. data_seed (`int`, defaults to `None`):
  651. The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator
  652. will use the current default seed from torch.
  653. non_blocking (`bool`, defaults to `False`):
  654. If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device
  655. transfers, allowing for better overlap between dataloader communication and computation. Recommended that
  656. the prepared dataloader has `pin_memory` set to `True` to work properly.
  657. use_stateful_dataloader (`bool`, defaults to `False`):
  658. If set to `True`, the dataloader prepared by the Accelerator will be backed by
  659. [torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader).
  660. This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed.
  661. """
  662. split_batches: bool = field(
  663. default=False,
  664. metadata={
  665. "help": "Whether or not the accelerator should split the batches yielded by the dataloaders across the devices. If"
  666. " `True` the actual batch size used will be the same on any kind of distributed processes, but it must be a"
  667. " round multiple of the `num_processes` you are using. If `False`, actual batch size used will be the one set"
  668. " in your script multiplied by the number of processes."
  669. },
  670. )
  671. dispatch_batches: bool = field(
  672. default=None,
  673. metadata={
  674. "help": "If set to `True`, the dataloader prepared by the Accelerator is only iterated through on the main process"
  675. " and then the batches are split and broadcast to each process. Will default to `True` for `DataLoader` whose"
  676. " underlying dataset is an `IterableDataset`, `False` otherwise."
  677. },
  678. )
  679. even_batches: bool = field(
  680. default=True,
  681. metadata={
  682. "help": "If set to `True`, in cases where the total batch size across all processes does not exactly divide the"
  683. " dataset, samples at the start of the dataset will be duplicated so the batch can be divided equally among"
  684. " all workers."
  685. },
  686. )
  687. use_seedable_sampler: bool = field(
  688. default=False,
  689. metadata={
  690. "help": "Whether or not use a fully seedable random sampler ([`data_loader.SeedableRandomSampler`])."
  691. "Ensures training results are fully reproducible using a different sampling technique. "
  692. "While seed-to-seed results may differ, on average the differences are negligible when using"
  693. "multiple different seeds to compare. Should also be ran with [`~utils.set_seed`] for the best results."
  694. },
  695. )
  696. data_seed: int = field(
  697. default=None,
  698. metadata={
  699. "help": "The seed to use for the underlying generator when using `use_seedable_sampler`. If `None`, the generator"
  700. " will use the current default seed from torch."
  701. },
  702. )
  703. non_blocking: bool = field(
  704. default=False,
  705. metadata={
  706. "help": "If set to `True`, the dataloader prepared by the Accelerator will utilize non-blocking host-to-device"
  707. " transfers, allowing for better overlap between dataloader communication and computation. Recommended that the"
  708. " prepared dataloader has `pin_memory` set to `True` to work properly."
  709. },
  710. )
  711. use_stateful_dataloader: bool = field(
  712. default=False,
  713. metadata={
  714. "help": "If set to `True`, the dataloader prepared by the Accelerator will be backed by "
  715. "[torchdata.StatefulDataLoader](https://github.com/pytorch/data/tree/main/torchdata/stateful_dataloader). This requires `torchdata` version 0.8.0 or higher that supports StatefulDataLoader to be installed."
  716. },
  717. )
  718. @dataclass
  719. class ProjectConfiguration:
  720. """
  721. Configuration for the Accelerator object based on inner-project needs.
  722. Args:
  723. project_dir (`str`, defaults to `None`):
  724. A path to a directory for storing data.
  725. logging_dir (`str`, defaults to `None`):
  726. A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`.
  727. automatic_checkpoint_naming (`bool`, defaults to `False`):
  728. Whether saved states should be automatically iteratively named.
  729. total_limit (`int`, defaults to `None`):
  730. The maximum number of total saved states to keep.
  731. iteration (`int`, defaults to `0`):
  732. The current save iteration.
  733. save_on_each_node (`bool`, defaults to `False`):
  734. When doing multi-node distributed training, whether to save models and checkpoints on each node, or only on
  735. the main one.
  736. """
  737. project_dir: str = field(default=None, metadata={"help": "A path to a directory for storing data."})
  738. logging_dir: str = field(
  739. default=None,
  740. metadata={
  741. "help": "A path to a directory for storing logs of locally-compatible loggers. If None, defaults to `project_dir`."
  742. },
  743. )
  744. automatic_checkpoint_naming: bool = field(
  745. default=False,
  746. metadata={"help": "Whether saved states should be automatically iteratively named."},
  747. )
  748. total_limit: int = field(
  749. default=None,
  750. metadata={"help": "The maximum number of total saved states to keep."},
  751. )
  752. iteration: int = field(
  753. default=0,
  754. metadata={"help": "The current save iteration."},
  755. )
  756. save_on_each_node: bool = field(
  757. default=False,
  758. metadata={
  759. "help": (
  760. "When doing multi-node distributed training, whether to save models and checkpoints on each node, or"
  761. " only on the main one"
  762. )
  763. },
  764. )
  765. def set_directories(self, project_dir: Optional[str] = None):
  766. "Sets `self.project_dir` and `self.logging_dir` to the appropriate values."
  767. self.project_dir = project_dir
  768. if self.logging_dir is None:
  769. self.logging_dir = project_dir
  770. def __post_init__(self):
  771. self.set_directories(self.project_dir)
  772. @dataclass
  773. class GradientAccumulationPlugin(KwargsHandler):
  774. """
  775. A plugin to configure gradient accumulation behavior. You can only pass one of `gradient_accumulation_plugin` or
  776. `gradient_accumulation_steps` to [`Accelerator`]. Passing both raises an error.
  777. Parameters:
  778. num_steps (`int`):
  779. The number of steps to accumulate gradients for.
  780. adjust_scheduler (`bool`, *optional*, defaults to `True`):
  781. Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be
  782. `True` if the used scheduler was not adjusted for gradient accumulation.
  783. sync_with_dataloader (`bool`, *optional*, defaults to `True`):
  784. Whether to synchronize setting the gradients when at the end of the dataloader.
  785. sync_each_batch (`bool`, *optional*):
  786. Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory
  787. requirements when using gradient accumulation with distributed training, at expense of speed.
  788. Example:
  789. ```python
  790. from accelerate.utils import GradientAccumulationPlugin
  791. gradient_accumulation_plugin = GradientAccumulationPlugin(num_steps=2)
  792. accelerator = Accelerator(gradient_accumulation_plugin=gradient_accumulation_plugin)
  793. ```
  794. """
  795. num_steps: int = field(
  796. default=None,
  797. metadata={"help": "The number of steps to accumulate gradients for."},
  798. )
  799. adjust_scheduler: bool = field(
  800. default=True,
  801. metadata={
  802. "help": "Whether to adjust the scheduler steps to account for the number of steps being accumulated. Should be `True` if the used scheduler was not adjusted for gradient accumulation."
  803. },
  804. )
  805. sync_with_dataloader: bool = field(
  806. default=True,
  807. metadata={
  808. "help": "Whether to synchronize setting the gradients when at the end of the dataloader. Should only be set to `False` if you know what you're doing."
  809. },
  810. )
  811. sync_each_batch: bool = field(
  812. default=False,
  813. metadata={
  814. "help": "Whether to synchronize setting the gradients at each data batch. Setting to `True` may reduce memory requirements when using gradient accumulation with distributed training, at expense of speed."
  815. },
  816. )
  817. @dataclass
  818. class TorchDynamoPlugin(KwargsHandler):
  819. """
  820. This plugin is used to compile a model with PyTorch 2.0
  821. Args:
  822. backend (`DynamoBackend`, defaults to `None`):
  823. A valid Dynamo backend. See https://pytorch.org/docs/stable/torch.compiler.html for more details.
  824. mode (`str`, defaults to `None`):
  825. Possible options are 'default', 'reduce-overhead' or 'max-autotune'.
  826. fullgraph (`bool`, defaults to `None`):
  827. Whether it is ok to break model into several subgraphs.
  828. dynamic (`bool`, defaults to `None`):
  829. Whether to use dynamic shape for tracing.
  830. options (`Any`, defaults to `None`):
  831. A dictionary of options to pass to the backend.
  832. disable (`bool`, defaults to `False`):
  833. Turn torch.compile() into a no-op for testing
  834. use_regional_compilation (`bool`, defaults to `None`):
  835. Use it to reduce the cold start compilation time of torch.compile() by targeting repeated blocks of the
  836. same class and compiling them sequentially to hit the compiler's cache. For example, in `GPT2LMHeadModel`,
  837. the repeated block/class is `GPT2Block`, and can be accessed as `model.transformer.h[0]`. The rest of the
  838. model (e.g model.lm_head) is compiled separately.
  839. """
  840. backend: DynamoBackend = field(
  841. default=None,
  842. metadata={"help": f"Possible options are {[b.value.lower() for b in DynamoBackend]}"},
  843. )
  844. mode: str = field(
  845. default=None,
  846. metadata={"help": "Possible options are 'default', 'reduce-overhead' or 'max-autotune'"},
  847. )
  848. fullgraph: bool = field(
  849. default=None,
  850. metadata={"help": "Whether it is ok to break model into several subgraphs"},
  851. )
  852. dynamic: bool = field(default=None, metadata={"help": "Whether to use dynamic shape for tracing"})
  853. options: Any = field(
  854. default=None,
  855. metadata={"help": "A dictionary of options to pass to the backend."},
  856. )
  857. disable: bool = field(
  858. default=False,
  859. metadata={"help": "Turn torch.compile() into a no-op for testing"},
  860. )
  861. use_regional_compilation: bool = field(
  862. default=None,
  863. metadata={
  864. "help": (
  865. # https://pytorch.org/tutorials/recipes/regional_compilation.html
  866. "Use it to reduce the cold start compilation time of torch.compile() by targeting repeated "
  867. "blocks of the same class and compiling them sequentially to hit the compiler's cache. For "
  868. "example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be accessed "
  869. "as `model.transformer.h[0]`. The rest of the model (e.g model.lm_head) is compiled separately."
  870. )
  871. },
  872. )
  873. def __post_init__(self):
  874. prefix = "ACCELERATE_DYNAMO_"
  875. if self.backend is None:
  876. self.backend = os.environ.get(prefix + "BACKEND", "no")
  877. self.backend = DynamoBackend(self.backend.upper())
  878. if self.mode is None:
  879. self.mode = os.environ.get(prefix + "MODE", "default")
  880. if self.fullgraph is None:
  881. self.fullgraph = str_to_bool(os.environ.get(prefix + "USE_FULLGRAPH", "False")) == 1
  882. if self.use_regional_compilation is None:
  883. self.use_regional_compilation = (
  884. str_to_bool(os.environ.get(prefix + "USE_REGIONAL_COMPILATION", "False")) == 1
  885. )
  886. if self.dynamic is None and os.environ.get(prefix + "USE_DYNAMIC", None) is not None:
  887. self.dynamic = str_to_bool(os.environ.get(prefix + "USE_DYNAMIC", "False")) == 1
  888. def to_dict(self):
  889. dynamo_config = copy.deepcopy(self.__dict__)
  890. dynamo_config["backend"] = dynamo_config["backend"].value.lower()
  891. return dynamo_config
  892. def to_kwargs(self):
  893. kwargs = super().to_kwargs()
  894. kwargs.pop("use_regional_compilation", None)
  895. return kwargs
  896. @dataclass
  897. class DeepSpeedPlugin:
  898. """
  899. This plugin is used to integrate DeepSpeed.
  900. Args:
  901. hf_ds_config (`Any`, defaults to `None`):
  902. Path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`.
  903. gradient_accumulation_steps (`int`, defaults to `None`):
  904. Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value
  905. from the `Accelerator` directly.
  906. gradient_clipping (`float`, defaults to `None`):
  907. Enable gradient clipping with value.
  908. zero_stage (`int`, defaults to `None`):
  909. Possible options are 0, 1, 2, 3. Default will be taken from environment variable.
  910. is_train_batch_min (`bool`, defaults to `True`):
  911. If both train & eval dataloaders are specified, this will decide the `train_batch_size`.
  912. offload_optimizer_device (`str`, defaults to `None`):
  913. Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3.
  914. offload_param_device (`str`, defaults to `None`):
  915. Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3.
  916. offload_optimizer_nvme_path (`str`, defaults to `None`):
  917. Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.
  918. offload_param_nvme_path (`str`, defaults to `None`):
  919. Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3.
  920. zero3_init_flag (`bool`, defaults to `None`):
  921. Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models. Only applicable with ZeRO Stage-3.
  922. zero3_save_16bit_model (`bool`, defaults to `None`):
  923. Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3.
  924. transformer_moe_cls_names (`str`, defaults to `None`):
  925. Comma-separated list of Transformers MoE layer class names (case-sensitive). For example,
  926. `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention`, `JetMoEBlock`, etc.
  927. enable_msamp (`bool`, defaults to `None`):
  928. Flag to indicate whether to enable MS-AMP backend for FP8 training.
  929. msamp_opt_level (`Optional[Literal["O1", "O2"]]`, defaults to `None`):
  930. Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one
  931. of ['O1' or 'O2'].
  932. """
  933. hf_ds_config: Any = field(
  934. default=None,
  935. metadata={
  936. "help": "path to DeepSpeed config file or dict or an object of class `accelerate.utils.deepspeed.HfDeepSpeedConfig`."
  937. },
  938. )
  939. gradient_accumulation_steps: int = field(
  940. default=None,
  941. metadata={
  942. "help": "Number of steps to accumulate gradients before updating optimizer states. If not set, will use the value from the `Accelerator` directly."
  943. },
  944. )
  945. gradient_clipping: float = field(default=None, metadata={"help": "Enable gradient clipping with value"})
  946. zero_stage: int = field(
  947. default=None,
  948. metadata={"help": "Possible options are 0,1,2,3; Default will be taken from environment variable"},
  949. )
  950. is_train_batch_min: bool = field(
  951. default=True,
  952. metadata={"help": "If both train & eval dataloaders are specified, this will decide the train_batch_size"},
  953. )
  954. offload_optimizer_device: str = field(
  955. default=None,
  956. metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stages 2 and 3."},
  957. )
  958. offload_param_device: str = field(
  959. default=None,
  960. metadata={"help": "Possible options are none|cpu|nvme. Only applicable with ZeRO Stage 3."},
  961. )
  962. offload_optimizer_nvme_path: str = field(
  963. default=None,
  964. metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."},
  965. )
  966. offload_param_nvme_path: str = field(
  967. default=None,
  968. metadata={"help": "Possible options are /nvme|/local_nvme. Only applicable with ZeRO Stage 3."},
  969. )
  970. zero3_init_flag: bool = field(
  971. default=None,
  972. metadata={
  973. "help": "Flag to indicate whether to enable `deepspeed.zero.Init` for constructing massive models."
  974. "Only applicable with ZeRO Stage-3."
  975. },
  976. )
  977. zero3_save_16bit_model: bool = field(
  978. default=None,
  979. metadata={"help": "Flag to indicate whether to save 16-bit model. Only applicable with ZeRO Stage-3."},
  980. )
  981. transformer_moe_cls_names: str = field(
  982. default=None,
  983. metadata={
  984. "help": "comma-separated list of transformers MoE layer class names (case-sensitive), e.g : "
  985. " `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..."
  986. },
  987. )
  988. enable_msamp: bool = field(
  989. default=None,
  990. metadata={"help": "Flag to indicate whether to enable MS-AMP backend for FP8 training."},
  991. )
  992. msamp_opt_level: Optional[Literal["O1", "O2"]] = field(
  993. default=None,
  994. metadata={
  995. "help": "Optimization level for MS-AMP (defaults to 'O1'). Only applicable if `enable_msamp` is True. Should be one of ['O1' or 'O2']."
  996. },
  997. )
    def __post_init__(self):
        """Resolve unset options from environment variables and prepare the DeepSpeed configuration.

        Options that are still `None` are read from their `ACCELERATE_*` environment variables. When
        `hf_ds_config` is a dict, a string other than `"none"`, or an `HfDeepSpeedConfig`, that configuration is
        used and the plugin attributes are synchronised with it; otherwise a configuration dict is assembled from
        the plugin attributes. The result is exposed as `self.deepspeed_config`.

        Raises:
            ValueError: If the supplied configuration has no `zero_optimization` entry, or if MS-AMP is being
                enabled with an `msamp_opt_level` other than `"O1"` / `"O2"`.
            NotImplementedError: If MS-AMP is being enabled while `zero_stage` is 3.
        """
        from .deepspeed import HfDeepSpeedConfig

        # --- Fill unset options from the environment -------------------------------------------------------
        if self.gradient_accumulation_steps is None:
            gas = os.environ.get("ACCELERATE_GRADIENT_ACCUMULATION_STEPS", "auto")
            # Digit-only strings become ints; anything else (e.g. "auto") is kept as the raw string.
            self.gradient_accumulation_steps = int(gas) if gas.isdigit() else gas

        if self.gradient_clipping is None:
            gradient_clipping = os.environ.get("ACCELERATE_GRADIENT_CLIPPING", "auto")
            # "auto" is kept as-is; any other value is parsed as a float.
            self.gradient_clipping = gradient_clipping if gradient_clipping == "auto" else float(gradient_clipping)

        if self.zero_stage is None:
            self.zero_stage = int(os.environ.get("ACCELERATE_DEEPSPEED_ZERO_STAGE", 2))

        if self.offload_optimizer_device is None:
            self.offload_optimizer_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE", "none")

        if self.offload_param_device is None:
            self.offload_param_device = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE", "none")

        if self.offload_optimizer_nvme_path is None:
            self.offload_optimizer_nvme_path = os.environ.get(
                "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH", "none"
            )

        if self.offload_param_nvme_path is None:
            self.offload_param_nvme_path = os.environ.get("ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH", "none")

        if self.zero3_save_16bit_model is None:
            self.zero3_save_16bit_model = (
                os.environ.get("ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL", "false").lower() == "true"
            )
        if self.enable_msamp is None:
            self.enable_msamp = os.environ.get("ACCELERATE_FP8_BACKEND", None) == "MSAMP"

        if self.msamp_opt_level is None:
            self.msamp_opt_level = os.environ.get("ACCELERATE_FP8_OPT_LEVEL", "O1")

        # --- Use the supplied configuration, or build one from the plugin attributes -----------------------
        if self.hf_ds_config is None:
            # The string "none" is the sentinel for "no config file given".
            self.hf_ds_config = os.environ.get("ACCELERATE_DEEPSPEED_CONFIG_FILE", "none")
        if (
            isinstance(self.hf_ds_config, dict)
            or (isinstance(self.hf_ds_config, str) and self.hf_ds_config != "none")
            or isinstance(self.hf_ds_config, HfDeepSpeedConfig)
        ):
            # A dict or string is wrapped so the rest of the branch can rely on the `HfDeepSpeedConfig` API.
            if not isinstance(self.hf_ds_config, HfDeepSpeedConfig):
                self.hf_ds_config = HfDeepSpeedConfig(self.hf_ds_config)
            if "gradient_accumulation_steps" not in self.hf_ds_config.config:
                self.hf_ds_config.config["gradient_accumulation_steps"] = 1
            if "zero_optimization" not in self.hf_ds_config.config:
                raise ValueError("Please specify the ZeRO optimization config in the DeepSpeed config.")

            self._deepspeed_config_checks()
            # Maps plugin attribute names to dotted key paths inside the DeepSpeed config.
            plugin_to_config_mapping = {
                "gradient_accumulation_steps": "gradient_accumulation_steps",
                "gradient_clipping": "gradient_clipping",
                "zero_stage": "zero_optimization.stage",
                "offload_optimizer_device": "zero_optimization.offload_optimizer.device",
                "offload_param_device": "zero_optimization.offload_param.device",
                "offload_param_nvme_path": "zero_optimization.offload_param.nvme_path",
                "offload_optimizer_nvme_path": "zero_optimization.offload_optimizer.nvme_path",
                "zero3_save_16bit_model": "zero_optimization.stage3_gather_16bit_weights_on_model_save",
            }
            # Keyed by config path; attributes that are still `None` are left out.
            kwargs = {v: getattr(self, k) for k, v in plugin_to_config_mapping.items() if getattr(self, k) is not None}
            for key in kwargs.keys():
                self.fill_match(key, **kwargs, must_match=False)
            self.hf_ds_config.set_stage_and_offload()

            # filling the missing values in the class attributes from the DeepSpeed config
            # when using the DeepSpeed config file.
            for key, value in plugin_to_config_mapping.items():
                config_value = self.hf_ds_config.get_value(value)
                if config_value is not None and config_value != "auto":
                    setattr(self, key, config_value)
        else:
            # No configuration supplied: assemble one from the plugin attributes.
            config = {
                "train_batch_size": "auto",
                "train_micro_batch_size_per_gpu": "auto",
                "gradient_accumulation_steps": self.gradient_accumulation_steps,
                "zero_optimization": {
                    "stage": self.zero_stage,
                    "offload_optimizer": {
                        "device": self.offload_optimizer_device,
                        # The NVMe path is only forwarded when the matching device is "nvme".
                        "nvme_path": (
                            self.offload_optimizer_nvme_path if self.offload_optimizer_device == "nvme" else None
                        ),
                    },
                    "offload_param": {
                        "device": self.offload_param_device,
                        "nvme_path": (self.offload_param_nvme_path if self.offload_param_device == "nvme" else None),
                    },
                    "stage3_gather_16bit_weights_on_model_save": self.zero3_save_16bit_model,
                },
            }
            # Any truthy value is copied, which includes the string "auto".
            if self.gradient_clipping:
                config["gradient_clipping"] = self.gradient_clipping
            self.hf_ds_config = HfDeepSpeedConfig(config)

        # --- Post-processing shared by both branches -------------------------------------------------------
        self.deepspeed_config = self.hf_ds_config.config
        self.deepspeed_config["steps_per_print"] = float("inf")  # this will stop deepspeed from logging @ stdout
        if self.zero3_init_flag is None:
            # Defaults to whether the resolved config is ZeRO stage 3, unless the env variable overrides it.
            self.zero3_init_flag = (
                str_to_bool(
                    os.environ.get(
                        "ACCELERATE_DEEPSPEED_ZERO3_INIT",
                        str(self.hf_ds_config.is_zero3()),
                    )
                )
                == 1
            )
        if self.zero3_init_flag and not self.hf_ds_config.is_zero3():
            warnings.warn("DeepSpeed Zero3 Init flag is only applicable for ZeRO Stage 3. Setting it to False.")
            self.zero3_init_flag = False
        # NOTE: Set to False by default, will be set to `True` automatically if it's the first plugin passed
        # to the `Accelerator`'s `deepspeed_plugin` param, *or* `AcceleratorState().enable_deepspeed_plugin(plugin_key)` is manually called
        self._set_selected(False)

        # Ignore if it's already set
        if self.enable_msamp and "msamp" not in self.deepspeed_config:
            if self.zero_stage == 3:
                raise NotImplementedError(
                    "MS-AMP is not supported for ZeRO Stage 3. Please use ZeRO Stage 0, 1, or 2 instead."
                )
            if self.msamp_opt_level not in ["O1", "O2"]:
                raise ValueError("Invalid optimization level for MS-AMP. Please use one of ['O1' or'O2'].")
            self.deepspeed_config["msamp"] = {
                "enabled": True,
                "opt_level": self.msamp_opt_level,
            }
  def fill_match(self, ds_key_long, mismatches=None, must_match=True, **kwargs):
      """
      Reconcile one DeepSpeed config entry with the value supplied for it in `kwargs`.

      An entry set to `"auto"` is overwritten in place with the value from `kwargs`. Any other entry is left untouched
      and, when `must_match` is set, compared against the value from `kwargs`.

      Args:
          ds_key_long (`str`):
              Dotted path of the entry, e.g. `zero_optimization.stage`. It is also the key looked up in `kwargs`.
          mismatches (`list`, *optional*):
              List that receives one description per entry whose configured value differs from the supplied one. A
              throwaway list is used when omitted.
          must_match (`bool`, defaults to `True`):
              Whether a non-`"auto"` entry is compared against the supplied value.
          **kwargs:
              Candidate values keyed by dotted config path.

      Raises:
          ValueError: If the entry is `"auto"` and `kwargs` holds no value for `ds_key_long`.
      """
      if mismatches is None:
          mismatches = []

      node, leaf_key = self.hf_ds_config.find_config_node(ds_key_long)
      if node is None:
          # The parent section does not exist in the config, nothing to reconcile.
          return

      configured = node.get(leaf_key)
      if configured == "auto":
          if ds_key_long not in kwargs:
              raise ValueError(
                  f"`{ds_key_long}` not found in kwargs. "
                  f"Please specify `{ds_key_long}` without `auto` (set to correct value) in the DeepSpeed config file or "
                  "pass it in kwargs."
              )
          node[leaf_key] = kwargs[ds_key_long]
          return

      if not must_match or configured is None or ds_key_long not in kwargs:
          return

      supplied = kwargs[ds_key_long]
      if configured != supplied:
          mismatches.append(f"- ds {ds_key_long}={configured} vs arg {ds_key_long}={supplied}")
  def is_auto(self, ds_key_long):
      """
      Tell whether the DeepSpeed config entry at `ds_key_long` is set to `"auto"`.

      Args:
          ds_key_long (`str`):
              Dotted path of the entry to inspect.

      Returns:
          `False` when the lookup yields `None`, otherwise the result of comparing the value with `"auto"`.
      """
      current = self.hf_ds_config.get_value(ds_key_long)
      return current is not None and current == "auto"
  def get_value(self, ds_key_long, default=None):
      """
      Look up `ds_key_long` through the wrapped `hf_ds_config` object.

      Args:
          ds_key_long (`str`):
              Dotted path of the entry to read.
          default (`Any`, *optional*):
              Forwarded unchanged as the second argument of `hf_ds_config.get_value`.

      Returns:
          Whatever `hf_ds_config.get_value(ds_key_long, default)` returns.
      """
      lookup = self.hf_ds_config.get_value
      return lookup(ds_key_long, default)
  def deepspeed_config_process(self, prefix="", mismatches=None, config=None, must_match=True, **kwargs):
      """
      Walk the DeepSpeed config recursively and reconcile every leaf entry with `kwargs` via `fill_match`.

      Args:
          prefix (`str`, defaults to `""`):
              Dotted path of the section being walked, including the trailing dot. Empty for the top-level call.
          mismatches (`list`, *optional*):
              Accumulator shared across the recursion. A new list is created when omitted.
          config (`dict`, *optional*):
              Section to walk. Defaults to `self.deepspeed_config`.
          must_match (`bool`, defaults to `True`):
              Forwarded to `fill_match`.
          **kwargs:
              Candidate values keyed by dotted config path, forwarded to `fill_match`.

      Raises:
          ValueError: From the top-level call only, when at least one mismatch has been collected.
      """
      if mismatches is None:
          mismatches = []
      if config is None:
          config = self.deepspeed_config

      for name, entry in config.items():
          dotted = prefix + name
          if not isinstance(entry, dict):
              self.fill_match(dotted, mismatches, must_match=must_match, **kwargs)
              continue
          # Nested section: descend with the extended prefix and the shared accumulator.
          self.deepspeed_config_process(
              prefix=dotted + ".",
              mismatches=mismatches,
              config=entry,
              must_match=must_match,
              **kwargs,
          )

      # Only the outermost call (empty prefix) reports, so every mismatch ends up in a single error.
      if prefix == "" and len(mismatches) > 0:
          mismatches_msg = "\n".join(mismatches)
          raise ValueError(
              "Please correct the following DeepSpeed config values that mismatch kwargs "
              f" values:\n{mismatches_msg}\nThe easiest method is to set these DeepSpeed config values to 'auto'."
          )
  def set_mixed_precision(self, mixed_precision):
      """
      Align the `fp16` / `bf16` sections of the DeepSpeed config with the requested mixed precision mode.

      Args:
          mixed_precision (`str`):
              Requested mode. `"fp16"` enables the `fp16` section, `"bf16"` and `"fp8"` enable the `bf16` section,
              and `"no"` skips the conflict check.

      Raises:
          ValueError: If the config enables the opposite 16-bit dtype to the one requested.
      """
      ds_config = self.deepspeed_config
      wants_fp16 = mixed_precision == "fp16"
      # When training in fp8, we still rely on bf16 autocast for the core mixed precision
      wants_bf16 = mixed_precision in ("bf16", "fp8")

      # Create the section for the requested dtype, unless the config already defines one.
      if wants_fp16:
          ds_config.setdefault("fp16", {"enabled": True, "auto_cast": True})
      elif wants_bf16:
          ds_config.setdefault("bf16", {"enabled": True})

      if mixed_precision == "fp8" and self.enable_msamp and "msamp" not in ds_config:
          ds_config["msamp"] = {
              "enabled": True,
              "opt_level": self.msamp_opt_level,
          }

      if mixed_precision != "no":
          other_dtype = "bf16" if wants_fp16 else "fp16"
          other_enabled = ds_config.get(other_dtype, {}).get("enabled", "False")
          # Compared as a lowercase string so that both `True` and "true" count as enabled.
          if str(other_enabled).lower() == "true":
              raise ValueError(
                  f"`--mixed_precision` arg cannot be set to `{mixed_precision}` when `{other_dtype}` is set in the DeepSpeed config file."
              )

      # Both sections must exist before `fill_match` resolves their `enabled` entries.
      for dtype in ("fp16", "bf16"):
          ds_config.setdefault(dtype, {"enabled": False})

      resolved = {"fp16.enabled": wants_fp16, "bf16.enabled": wants_bf16}
      self.fill_match("fp16.enabled", must_match=False, **resolved)
      self.fill_match("bf16.enabled", must_match=False, **resolved)
  def set_deepspeed_weakref(self):
      """
      Build a standalone copy of this plugin's DeepSpeed config and wrap it in Transformers' `HfDeepSpeedConfig`.

      In the copy, `gradient_accumulation_steps` and `train_micro_batch_size_per_gpu` are set to `1` when they are
      missing or `"auto"`, and `train_batch_size` is dropped when it is `"auto"`. `self.deepspeed_config` itself is
      not modified because the work is done on a deep copy. The resulting object is stored on `self.dschf`.

      Raises:
          Exception: If `zero3_init_flag` is set and Transformers is not installed.
      """
      from .imports import is_transformers_available

      ds_config = copy.deepcopy(self.deepspeed_config)
      # NOTE(review): nesting reconstructed from a flattened listing; confirm that only the Transformers
      # availability check is guarded by `zero3_init_flag`.
      if self.zero3_init_flag:
          if not is_transformers_available():
              raise Exception(
                  "When `zero3_init_flag` is set, it requires Transformers to be installed. "
                  "Please run `pip install transformers`."
              )
      # Replace missing or "auto" batch settings with the concrete value 1.
      if "gradient_accumulation_steps" not in ds_config or ds_config["gradient_accumulation_steps"] == "auto":
          ds_config["gradient_accumulation_steps"] = 1
      if "train_micro_batch_size_per_gpu" not in ds_config or ds_config["train_micro_batch_size_per_gpu"] == "auto":
          ds_config["train_micro_batch_size_per_gpu"] = 1
      if ds_config.get("train_batch_size", None) == "auto":
          del ds_config["train_batch_size"]
      # The import location depends on the installed Transformers version.
      if compare_versions("transformers", "<", "4.46"):
          from transformers.deepspeed import (
              HfDeepSpeedConfig,
              unset_hf_deepspeed_config,
          )
      else:
          from transformers.integrations import (
              HfDeepSpeedConfig,
              unset_hf_deepspeed_config,
          )
      unset_hf_deepspeed_config()
      # Presumably Transformers only keeps a weak reference to the config object (hence the method name), so the
      # plugin has to hold the strong one. TODO confirm against Transformers.
      self.dschf = HfDeepSpeedConfig(ds_config)  # keep this object alive # noqa
  def is_zero3_init_enabled(self):
      """Return the current value of `zero3_init_flag`."""
      flag = self.zero3_init_flag
      return flag
  @contextmanager
  def zero3_init_context_manager(self, enable=False):
      """
      Temporarily force `zero3_init_flag` to `enable` for the duration of a `with` block.

      When the flag already equals `enable`, nothing is changed. Otherwise the flag is switched, `self.dschf` is
      reset and `set_deepspeed_weakref()` is called on entry, and the same steps run with the previous flag on exit.
      The previous state is restored even if the body of the `with` block raises.

      Args:
          enable (`bool`, defaults to `False`):
              Value `zero3_init_flag` should have inside the `with` block.
      """
      previous = self.zero3_init_flag
      if previous == enable:
          # Already in the requested state: nothing to toggle and nothing to restore.
          yield
          return

      self.zero3_init_flag = enable
      self.dschf = None
      self.set_deepspeed_weakref()
      try:
          yield
      finally:
          # Runs on normal exit and on exceptions alike, so a failure inside the block cannot leave the
          # plugin stuck with the temporary flag.
          self.zero3_init_flag = previous
          self.dschf = None
          self.set_deepspeed_weakref()
  def _deepspeed_config_checks(self):
      """
      Reject accelerate config fields that are ignored when a DeepSpeed config file is used.

      The comma-separated field names in the `ACCELERATE_CONFIG_DS_FIELDS` environment variable are compared against
      a fixed list of ignored fields.

      Raises:
          ValueError: If any listed field is one of the ignored fields.
      """
      # Field names are the environment variable names without their prefixes, in lower case.
      ignored_fields = [
          env_name.replace("ACCELERATE_", "").replace("DEEPSPEED_", "").lower()
          for env_name in (
              "ACCELERATE_GRADIENT_ACCUMULATION_STEPS",
              "ACCELERATE_GRADIENT_CLIPPING",
              "ACCELERATE_DEEPSPEED_ZERO_STAGE",
              "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_DEVICE",
              "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_DEVICE",
              "ACCELERATE_DEEPSPEED_OFFLOAD_PARAM_NVME_PATH",
              "ACCELERATE_DEEPSPEED_OFFLOAD_OPTIMIZER_NVME_PATH",
              "ACCELERATE_DEEPSPEED_ZERO3_SAVE_16BIT_MODEL",
              "ACCELERATE_MIXED_PRECISION",
          )
      ]

      configured_fields = os.environ.get("ACCELERATE_CONFIG_DS_FIELDS", "").split(",")
      if not any(field_name in ignored_fields for field_name in configured_fields):
          return

      raise ValueError(
          f"When using `deepspeed_config_file`, the following accelerate config variables will be ignored: {ignored_fields}.\n"
          "Please specify them appropriately in the DeepSpeed config file.\n"
          "If you are using an accelerate config file, remove others config variables mentioned in the above specified list.\n"
          "The easiest method is to create a new config following the questionnaire via `accelerate config`.\n"
          "It will only ask for the necessary config variables when using `deepspeed_config_file`."
      )
  def set_moe_leaf_modules(self, model):
      """
      Pass the configured MoE layer classes of `model` to DeepSpeed's `set_z3_leaf_modules`.

      The class names come from `self.transformer_moe_cls_names` or, when that is `None`, from the
      `ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES` environment variable, as a comma-separated string. Nothing happens
      when neither is set.

      Args:
          model:
              Model in which the classes are looked up and which is passed to `set_z3_leaf_modules`.

      Raises:
          ImportError: If the installed DeepSpeed is older than 0.14.0.
          Exception: If one of the class names cannot be found in `model`.
      """
      if self.transformer_moe_cls_names is None:
          self.transformer_moe_cls_names = os.environ.get("ACCELERATE_DEEPSPEED_MOE_LAYER_CLS_NAMES", None)
      if self.transformer_moe_cls_names is None:
          return

      if compare_versions("deepspeed", "<", "0.14.0"):
          raise ImportError("DeepSpeed version must be >= 0.14.0 to use MOE support. Please update DeepSpeed.")
      from deepspeed.utils import set_z3_leaf_modules

      leaf_classes = []
      for cls_name in self.transformer_moe_cls_names.split(","):
          resolved = get_module_class_from_name(model, cls_name)
          if resolved is None:
              raise Exception(
                  f"Could not find a transformer layer class called '{cls_name}' to wrap in the model."
              )
          leaf_classes.append(resolved)
      set_z3_leaf_modules(model, leaf_classes)  # z3_leaf
  def select(self, _from_accelerator_state: bool = False):
      """
      Make this plugin's configuration the active one: refresh the HfDeepSpeedWeakref and mark the plugin selected.

      Args:
          _from_accelerator_state (`bool`, defaults to `False`):
              Guard argument. The call is rejected unless it is truthy.

      Raises:
          ValueError: If `_from_accelerator_state` is falsy.
      """
      if _from_accelerator_state:
          self.set_deepspeed_weakref()
          self._set_selected(True)
          return
      raise ValueError(
          "A `DeepSpeedPlugin` object must be enabled manually by calling `AcceleratorState().enable_deepspeed_plugin(plugin_key)`."
      )
  def _unselect(self):
      """Mark this plugin as not selected by calling `_set_selected(False)`."""
      now_selected = False
      self._set_selected(now_selected)
  def _set_selected(self, value: bool):
      """
      Store `value` in the private `_selected` attribute, which the `selected` property reads.

      Args:
          value (`bool`):
              New selection state.
      """
      self._selected = value
  @property
  def selected(self):
      """Read-only view of the private `_selected` attribute."""
      return self._selected

  @selected.setter
  def selected(self, value):
      """Reject direct assignment by always raising `NotImplementedError`."""
      message = "'enabled' can only be set through calling 'AcceleratorState().enable_deepspeed_plugin(key)'."
      raise NotImplementedError(message)
  1303. @dataclass
  1304. class FullyShardedDataParallelPlugin:
  1305. """
  1306. This plugin is used to enable fully sharded data parallelism.
  1307. Args:
  1308. fsdp_version (`int`, defaults to `1`):
  1309. The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to
  1310. FSDP2 format.
  1311. sharding_strategy (`Union[str, torch.distributed.fsdp.ShardingStrategy]`, defaults to `'FULL_SHARD'`):
  1312. Sharding strategy to use. Should be either a `str` or an instance of
  1313. `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Is deprecated in favor of
  1314. `reshard_after_forward`.
  1315. reshard_after_forward (`Union[str, torch.distributed.fsdp.ShardingStrategy, bool]`, defaults to `'FULL_SHARD'` for `fsdp_version=1` and `True` for `fsdp_version=2`):
  1316. Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of
  1317. `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`.
  1318. backward_prefetch (`Union[str, torch.distributed.fsdp.BackwardPrefetch]`, defaults to `'NO_PREFETCH'`):
  1319. Backward prefetch strategy to use. Should be either a `str` or an instance of
  1320. `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`.
  1321. mixed_precision_policy (`Optional[Union[dict, str, torch.distributed.fsdp.MixedPrecision, torch.distributed.fsdp.MixedPrecisionPolicy]]`, defaults to `None`):
  1322. A config to enable mixed precision training with FullyShardedDataParallel. If passing in a `dict`, it
  1323. should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`, can be an instance of
  1324. `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2. If passing in a `str`, it
  1325. should be one of the following values: fp8, fp16, bf16, fp32, and used to set `param_dtype`,
  1326. `reduce_dtype`, and `buffer_dtype`.
  1327. auto_wrap_policy (`Optional(Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]), defaults to `NO_WRAP`):
  1328. A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one
  1329. of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. See
  1330. `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like.
  1331. cpu_offload (`Union[bool, torch.distributed.fsdp.CPUOffload, torch.distributed.fsdp.CPUOffloadPolicy]`, defaults to `False`):
  1332. Whether to offload parameters to CPU. Should be either a `bool` or an instance of
  1333. `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or
  1334. `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2.
  1335. ignored_modules (`Optional[Union[Iterable[torch.nn.Module], str]]`, defaults to `None`):
  1336. A list of modules to ignore when wrapping with FSDP. When passing a string, will match the modules by name
  1337. using regex fullmatch. If `fsdp_version` is set to 2, the modules are converted to parameters and used.
  1338. state_dict_type (`Union[str, torch.distributed.fsdp.StateDictType]`, defaults to `'FULL_STATE_DICT'`):
  1339. State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or
  1340. `sharded_state_dict`.
  1341. state_dict_config (`Optional[Union[torch.distributed.fsdp.FullStateDictConfig, torch.distributed.fsdp.ShardedStateDictConfig]`, defaults to `None`):
  1342. State dict config to use. Is determined based on the `state_dict_type` if not passed in.
  1343. optim_state_dict_config (`Optional[Union[torch.distributed.fsdp.FullOptimStateDictConfig, torch.distributed.fsdp.ShardedOptimStateDictConfig]`, defaults to `None`):
  1344. Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in.
  1345. limit_all_gathers (`bool`, defaults to `True`):
  1346. Whether to have FSDP explicitly synchronizes the CPU thread to prevent too many in-flight all-gathers. This
  1347. bool only affects the sharded strategies that schedule all-gathers. Enabling this can help lower the number
  1348. of CUDA malloc retries.
  1349. use_orig_params (`bool`, defaults to `False`):
  1350. Whether to use the original parameters for the optimizer.
  1351. param_init_fn (`Optional[Callable[[torch.nn.Module], None]`, defaults to `None`):
  1352. A `Callable[torch.nn.Module] -> None` that specifies how modules that are currently on the meta device
  1353. should be initialized onto an actual device. Only applicable when `sync_module_states` is `True`. By
  1354. default is a `lambda` which calls `to_empty` on the module.
  1355. sync_module_states (`bool`, defaults to `False`):
  1356. Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 to ensure they
  1357. are the same across all ranks after initialization. Defaults to `False` unless `cpu_ram_efficient_loading`
  1358. is `True`, then will be forcibly enabled.
  1359. forward_prefetch (`bool`, defaults to `False`):
  1360. Whether to have FSDP explicitly prefetches the next upcoming all-gather while executing in the forward
  1361. pass. only use with Static graphs.
  1362. activation_checkpointing (`bool`, defaults to `False`):
  1363. A technique to reduce memory usage by clearing activations of certain layers and recomputing them during a
  1364. backward pass. Effectively, this trades extra computation time for reduced memory usage.
  1365. cpu_ram_efficient_loading (`bool`, defaults to `None`):
  1366. If True, only the first process loads the pretrained model checkoint while all other processes have empty
  1367. weights. Only applicable for Transformers. When using this, `sync_module_states` needs to be `True`.
  1368. transformer_cls_names_to_wrap (`Optional[List[str]]`, defaults to `None`):
  1369. A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is
  1370. `transformer_based_wrap`.
  1371. min_num_params (`Optional[int]`, defaults to `None`):
  1372. The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy`
  1373. is `size_based_wrap`.
  1374. """
  1375. fsdp_version: int = field(
  1376. default=None,
  1377. metadata={
  1378. "help": "The version of FSDP to use. Defaults to 1. If set to 2, launcher expects the config to be converted to FSDP2 format."
  1379. },
  1380. )
  1381. sharding_strategy: Union[str, "torch.distributed.fsdp.ShardingStrategy"] = field(
  1382. default=None,
  1383. metadata={
  1384. "help": "Sharding strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'. Is deprecated in favor of `reshard_after_forward` "
  1385. },
  1386. )
  1387. reshard_after_forward: Union[str, "torch.distributed.fsdp.ShardingStrategy", bool] = field(
  1388. default=None,
  1389. metadata={
  1390. "help": "Sharding strategy to use. Should be a bool if `fsdp_version` is set to 2 else a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`. Defaults to 'FULL_SHARD'"
  1391. },
  1392. )
  1393. backward_prefetch: Optional[Union[str, "torch.distributed.fsdp.BackwardPrefetch"]] = field(
  1394. default=None,
  1395. metadata={
  1396. "help": "Backward prefetch strategy to use. Should be either a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.BackwardPrefetch`. Defaults to 'NO_PREFETCH'. This becomes obsolete in FSDP2."
  1397. },
  1398. )
  1399. mixed_precision_policy: Optional[
  1400. Union[
  1401. dict,
  1402. str,
  1403. "torch.distributed.fsdp.MixedPrecision",
  1404. "torch.distributed.fsdp.MixedPrecisionPolicy",
  1405. ]
  1406. ] = field(
  1407. default=None,
  1408. metadata={
  1409. "help": "A config to enable mixed precision training with FullyShardedDataParallel. "
  1410. "If passing in a `dict`, it should have the following keys: `param_dtype`, `reduce_dtype`, and `buffer_dtype`."
  1411. "Can also be an instance of `torch.distributed.fsdp.MixedPrecisionPolicy` if `fsdp_version` is set to 2."
  1412. },
  1413. )
  1414. auto_wrap_policy: Optional[Union[Callable, Literal["transformer_based_wrap", "size_based_wrap", "no_wrap"]]] = (
  1415. field(
  1416. default=None,
  1417. metadata={
  1418. "help": "A callable or string specifying a policy to recursively wrap layers with FSDP. If a string, it must be one of `transformer_based_wrap`, `size_based_wrap`, or `no_wrap`. "
  1419. "Defaults to `NO_WRAP`. See `torch.distributed.fsdp.wrap.size_based_wrap_policy` for a direction on what it should look like"
  1420. },
  1421. )
  1422. )
  1423. cpu_offload: Union[
  1424. bool,
  1425. "torch.distributed.fsdp.CPUOffload",
  1426. "torch.distributed.fsdp.CPUOffloadPolicy",
  1427. ] = field(
  1428. default=None,
  1429. metadata={
  1430. "help": "Whether to offload parameters to CPU. Should be either a `bool` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffload` or `torch.distributed.fsdp.fully_sharded_data_parallel.CPUOffloadPolicy` if `fsdp_version` is set to 2. Defaults to `False`"
  1431. },
  1432. )
  1433. ignored_modules: Optional[Union[Iterable[torch.nn.Module], str]] = field(
  1434. default=None,
  1435. metadata={"help": "A list of modules to ignore when wrapping with FSDP."},
  1436. )
  1437. state_dict_type: Union[str, "torch.distributed.fsdp.StateDictType"] = field(
  1438. default=None,
  1439. metadata={
  1440. "help": "State dict type to use. If a string, it must be one of `full_state_dict`, `local_state_dict`, or `sharded_state_dict`. Defaults to `FULL_STATE_DICT`"
  1441. },
  1442. )
  1443. state_dict_config: Optional[
  1444. Union[
  1445. "torch.distributed.fsdp.FullStateDictConfig",
  1446. "torch.distributed.fsdp.ShardedStateDictConfig",
  1447. ]
  1448. ] = field(
  1449. default=None,
  1450. metadata={"help": "State dict config to use. Is determined based on the `state_dict_type` if not passed in."},
  1451. )
  1452. optim_state_dict_config: Optional[
  1453. Union[
  1454. "torch.distributed.fsdp.FullOptimStateDictConfig",
  1455. "torch.distributed.fsdp.ShardedOptimStateDictConfig",
  1456. ]
  1457. ] = field(
  1458. default=None,
  1459. metadata={
  1460. "help": "Optim state dict config to use. Is determined based on the `state_dict_type` if not passed in."
  1461. },
  1462. )
  1463. limit_all_gathers: bool = field(
  1464. default=True,
  1465. metadata={
  1466. "help": "Whether to have FSDP explicitly synchronizes the CPU thread to prevent "
  1467. "too many in-flight all-gathers. This bool only affects the sharded strategies that schedule all-gathers. "
  1468. "Enabling this can help lower the number of CUDA malloc retries."
  1469. },
  1470. )
  1471. use_orig_params: Optional[bool] = field(
  1472. default=None,
  1473. metadata={
  1474. "help": "Whether to use the original parameters for the optimizer. Defaults to `False`. This becomes obsolete in FSDP2."
  1475. },
  1476. )
  1477. param_init_fn: Optional[Callable[[torch.nn.Module], None]] = field(
  1478. default=None,
  1479. metadata={
  1480. "help": "A Callable[torch.nn.Module] -> None that specifies how modules "
  1481. "that are currently on the meta device should be initialized onto an actual device. "
  1482. "Only applicable when `sync_module_states` is `True`. By default is a `lambda` which calls `to_empty` on the module."
  1483. },
  1484. )
  1485. sync_module_states: Optional[bool] = field(
  1486. default=None,
  1487. metadata={
  1488. "help": "Whether each individually wrapped FSDP unit should broadcast module parameters from rank 0 "
  1489. "to ensure they are the same across all ranks after initialization. Defaults to `False` unless "
  1490. "`cpu_ram_efficient_loading` is `True`, then will be forcibly enabled. This becomes obsolete in FSDP2."
  1491. },
  1492. )
  1493. forward_prefetch: bool = field(
  1494. default=None,
  1495. metadata={
  1496. "help": "Whether to have FSDP explicitly prefetches the next upcoming "
  1497. "all-gather while executing in the forward pass. only use with Static graphs. Defaults to `False`"
  1498. },
  1499. )
  1500. activation_checkpointing: bool = field(
  1501. default=None,
  1502. metadata={
  1503. "help": "A technique to reduce memory usage by clearing activations of "
  1504. "certain layers and recomputing them during a backward pass. Effectively, this trades extra computation time "
  1505. "for reduced memory usage. Defaults to `False`"
  1506. },
  1507. )
  1508. cpu_ram_efficient_loading: bool = field(
  1509. default=None,
  1510. metadata={
  1511. "help": "If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. "
  1512. "Only applicable for 🤗 Transformers. When using this, `sync_module_states` needs to be `True`. Defaults to `False`."
  1513. },
  1514. )
  1515. transformer_cls_names_to_wrap: Optional[list[str]] = field(
  1516. default=None,
  1517. metadata={
  1518. "help": "A list of transformer layer class names to wrap. Only applicable when `auto_wrap_policy` is `transformer_based_wrap`."
  1519. },
  1520. )
  1521. min_num_params: Optional[int] = field(
  1522. default=None,
  1523. metadata={
  1524. "help": "The minimum number of parameters a module must have to be wrapped. Only applicable when `auto_wrap_policy` is `size_based_wrap`."
  1525. },
  1526. )
  1527. def __post_init__(self):
  1528. from torch.distributed.fsdp import BackwardPrefetch, ShardingStrategy
  1529. _fsdp2_warnings = set()
  1530. env_prefix = "FSDP_"
  1531. # Strategy: By default we should always assume that values are passed in, else we check the environment variables
  1532. if self.fsdp_version is None:
  1533. self.fsdp_version = int(os.environ.get(env_prefix + "VERSION", "1"))
  1534. if self.fsdp_version == 2:
  1535. if not is_torch_version(">=", FSDP2_PYTORCH_VERSION):
  1536. raise ImportError(f"FSDP2 requires PyTorch >= {FSDP2_PYTORCH_VERSION}")
  1537. if self.sharding_strategy is not None:
  1538. # We cannot properly detect all of the cases, as by default `args.fsdp_sharding_strategy` is set to `fully_shard`
  1539. # Therefore we issue a warning only if the user has explicitly set it inside their plugin
  1540. _fsdp2_warnings.add(
  1541. "sharding_strategy is deprecated in favor of reshard_after_forward. "
  1542. "This will be removed in a future version of Accelerate."
  1543. )
  1544. if self.fsdp_version == 1:
  1545. if self.sharding_strategy is None:
  1546. self.sharding_strategy = os.environ.get(env_prefix + "SHARDING_STRATEGY", "FULL_SHARD")
  1547. if isinstance(self.sharding_strategy, str):
  1548. if self.sharding_strategy.upper() in FSDP_SHARDING_STRATEGY:
  1549. self.sharding_strategy = FSDP_SHARDING_STRATEGY.index(self.sharding_strategy.upper()) + 1
  1550. if isinstance(self.sharding_strategy, int) or self.sharding_strategy.isdigit():
  1551. self.sharding_strategy = ShardingStrategy(int(self.sharding_strategy))
  1552. else:
  1553. self.sharding_strategy = ShardingStrategy[self.sharding_strategy.upper()]
  1554. # Fallback to `reshard_after_forward` in FSDP1 if `sharding_strategy` is not set
  1555. if self.reshard_after_forward is None and self.sharding_strategy is None:
  1556. reshard_after_forward = os.environ.get(
  1557. env_prefix + "RESHARD_AFTER_FORWARD",
  1558. "true" if self.fsdp_version == 2 else "FULL_SHARD",
  1559. )
  1560. if self.fsdp_version == 2:
  1561. self.reshard_after_forward = str_to_bool(reshard_after_forward.lower(), to_bool=True)
  1562. else:
  1563. self.reshard_after_forward = reshard_after_forward
  1564. if isinstance(self.reshard_after_forward, str):
  1565. if self.fsdp_version == 2:
  1566. self.reshard_after_forward = str_to_bool(self.reshard_after_forward.lower(), to_bool=True)
  1567. else:
  1568. # We need to remap based on custom enum values for user readability
  1569. if self.reshard_after_forward.upper() in FSDP_SHARDING_STRATEGY:
  1570. self.reshard_after_forward = FSDP_SHARDING_STRATEGY.index(self.reshard_after_forward.upper()) + 1
  1571. if isinstance(self.reshard_after_forward, int) or self.reshard_after_forward.isdigit():
  1572. self.reshard_after_forward = ShardingStrategy(int(self.reshard_after_forward))
  1573. else:
  1574. self.reshard_after_forward = ShardingStrategy[self.reshard_after_forward.upper()]
  1575. if self.fsdp_version == 2 and not isinstance(self.reshard_after_forward, bool):
  1576. raise ValueError(
  1577. f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP2, please set to a `bool`"
  1578. )
  1579. if self.fsdp_version == 1 and isinstance(self.reshard_after_forward, bool):
  1580. raise ValueError(
  1581. f"reshard_after_forward set to {self.reshard_after_forward}. This is not supported with FSDP1, please set to a `str` or an instance of `torch.distributed.fsdp.fully_sharded_data_parallel.ShardingStrategy`"
  1582. )
  1583. if self.cpu_offload is None:
  1584. self.cpu_offload = str_to_bool(os.environ.get(env_prefix + "OFFLOAD_PARAMS", "False")) == 1
  1585. self.set_cpu_offload() # abstracted away to hide imports due to version checks
  1586. self.validate_cpu_offload()
  1587. if self.backward_prefetch is None:
  1588. self.backward_prefetch = os.environ.get(env_prefix + "BACKWARD_PREFETCH", None)
  1589. if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() == "NO_PREFETCH":
  1590. self.backward_prefetch = None
  1591. if self.backward_prefetch is not None and not isinstance(self.backward_prefetch, BackwardPrefetch):
  1592. if isinstance(self.backward_prefetch, str) and self.backward_prefetch.upper() in FSDP_BACKWARD_PREFETCH:
  1593. self.backward_prefetch = FSDP_BACKWARD_PREFETCH.index(self.backward_prefetch.upper()) + 1
  1594. if isinstance(self.backward_prefetch, int) or self.backward_prefetch.isdigit():
  1595. self.backward_prefetch = BackwardPrefetch(int(self.backward_prefetch))
  1596. else:
  1597. self.backward_prefetch = BackwardPrefetch[self.backward_prefetch.upper()]
  1598. if self.fsdp_version == 2 and self.backward_prefetch is not None:
  1599. _fsdp2_warnings.add("backward_prefetch is not supported in FSDP2. Setting backward prefetch to None.")
  1600. self.backward_prefetch = None
  1601. self.set_state_dict_type()
  1602. if self.auto_wrap_policy is None:
  1603. self.auto_wrap_policy = os.environ.get(env_prefix + "AUTO_WRAP_POLICY", "NO_WRAP")
  1604. if isinstance(self.auto_wrap_policy, str):
  1605. if self.auto_wrap_policy.upper() not in FSDP_AUTO_WRAP_POLICY:
  1606. raise ValueError(
  1607. f"Invalid auto wrap policy: {self.auto_wrap_policy}. Must be one of {FSDP_AUTO_WRAP_POLICY}"
  1608. )
  1609. from torch.distributed.fsdp.wrap import (
  1610. size_based_auto_wrap_policy,
  1611. transformer_auto_wrap_policy,
  1612. )
  1613. if self.auto_wrap_policy.upper() == "TRANSFORMER_BASED_WRAP":
  1614. self.auto_wrap_policy = transformer_auto_wrap_policy
  1615. if self.transformer_cls_names_to_wrap is None:
  1616. self.transformer_cls_names_to_wrap = os.environ.get(env_prefix + "TRANSFORMER_CLS_TO_WRAP", None)
  1617. if isinstance(self.transformer_cls_names_to_wrap, str):
  1618. self.transformer_cls_names_to_wrap = self.transformer_cls_names_to_wrap.split(",")
  1619. elif self.auto_wrap_policy.upper() == "SIZE_BASED_WRAP":
  1620. self.auto_wrap_policy = size_based_auto_wrap_policy
  1621. if self.min_num_params is None:
  1622. self.min_num_params = int(os.environ.get(env_prefix + "MIN_NUM_PARAMS", 0))
  1623. elif not isinstance(self.min_num_params, int):
  1624. raise ValueError(
  1625. f"`min_num_params` must be an integer. Got {self.min_num_params} of type {type(self.min_num_params)}"
  1626. )
  1627. elif self.auto_wrap_policy.upper() == "NO_WRAP":
  1628. self.auto_wrap_policy = None
  1629. if self.use_orig_params is None and self.fsdp_version == 1:
  1630. self.use_orig_params = str_to_bool(os.environ.get(env_prefix + "USE_ORIG_PARAMS", "False")) == 1
  1631. if self.fsdp_version == 2 and self.use_orig_params is not None:
  1632. _fsdp2_warnings.add("use_orig_params is obsolete in FSDP2, as FSDP2 always uses the original parameters.")
  1633. self.use_orig_params = None
  1634. if self.sync_module_states is None and self.fsdp_version == 1:
  1635. self.sync_module_states = str_to_bool(os.environ.get(env_prefix + "SYNC_MODULE_STATES", "False")) == 1
  1636. if self.fsdp_version == 2 and self.sync_module_states is not None:
  1637. _fsdp2_warnings.add(
  1638. "sync_module_states is obsolete in FSDP2, as it is not needed anymore."
  1639. "Setting sync_module_states to None."
  1640. )
  1641. self.sync_module_states = None
  1642. if self.forward_prefetch is None and self.fsdp_version == 1:
  1643. self.forward_prefetch = str_to_bool(os.environ.get(env_prefix + "FORWARD_PREFETCH", "False")) == 1
  1644. if self.fsdp_version == 2 and self.forward_prefetch is not None:
  1645. raise ValueError("forward_prefetch is not yet implemented in FSDP2, set to None or use `fsdp_version=1`")
  1646. if self.activation_checkpointing is None:
  1647. self.activation_checkpointing = (
  1648. str_to_bool(os.environ.get(env_prefix + "ACTIVATION_CHECKPOINTING", "False")) == 1
  1649. )
  1650. if self.ignored_modules is None:
  1651. self.ignored_modules = os.environ.get(env_prefix + "IGNORED_MODULES", None)
  1652. if self.cpu_ram_efficient_loading is None:
  1653. self.cpu_ram_efficient_loading = (
  1654. str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
  1655. )
  1656. else:
  1657. # We still need to set it for transformers
  1658. os.environ[env_prefix + "CPU_RAM_EFFICIENT_LOADING"] = str(self.cpu_ram_efficient_loading)
  1659. # There's no need to specify sync_module_states in FSDP2
  1660. if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states:
  1661. warnings.warn(
  1662. "sync_module_states cannot be False since efficient cpu ram loading enabled. "
  1663. "Setting sync_module_states to True."
  1664. )
  1665. self.sync_module_states = True
  1666. if isinstance(self.mixed_precision_policy, str):
  1667. # override is True since self.mixed_precision_policy is not None
  1668. # has to be overwritten with the correct mixed precision object
  1669. self.set_mixed_precision(self.mixed_precision_policy, override=True)
  1670. elif isinstance(self.mixed_precision_policy, dict):
  1671. self.set_mixed_precision(self.mixed_precision_policy)
  1672. if self.mixed_precision_policy is not None:
  1673. self.validate_mixed_precision_policy()
  1674. if self.sync_module_states:
  1675. if is_npu_available():
  1676. device = torch.npu.current_device()
  1677. elif is_mlu_available():
  1678. device = torch.mlu.current_device()
  1679. elif is_musa_available():
  1680. device = torch.musa.current_device()
  1681. elif is_cuda_available():
  1682. device = torch.cuda.current_device()
  1683. elif is_xpu_available():
  1684. device = torch.xpu.current_device()
  1685. elif is_hpu_available():
  1686. device = torch.hpu.current_device()
  1687. else:
  1688. raise RuntimeError(
  1689. "There are currently no available devices found, must be one of 'XPU', 'CUDA', 'MLU', 'NPU', 'MUSA', or 'HPU'."
  1690. )
  1691. # Create a function that will be used to initialize the parameters of the model
  1692. # when using `sync_module_states`
  1693. self.param_init_fn = lambda x: x.to_empty(device=device, recurse=False)
  1694. if is_torch_version("<", "2.7.0") and self.fsdp_version == 2 and self.ignored_modules is not None:
  1695. _fsdp2_warnings.add(
  1696. "FSDP2 ignored_params/ignored_modules is not available for torch version < 2.7.0"
  1697. "Setting ignored_modules to None."
  1698. )
  1699. self.ignored_modules = None
  1700. # Single warning for all deprecation warnings due to FSDP2 conversion
  1701. if _fsdp2_warnings:
  1702. logger.warning("Multiple deprecation warnings due to FSDP2 conversion:\n".join(_fsdp2_warnings))
  def set_state_dict_type(self, state_dict_type=None):
      """
      Resolve `self.state_dict_type` into a `StateDictType` member and fill in any missing state dict configs.

      Args:
          state_dict_type (`str` or `StateDictType`, *optional*):
              When given, replaces the value currently stored on the plugin (e.g. train with a sharded state dict
              but do the final save with a full one). When neither this argument nor the stored value is set, the
              `FSDP_STATE_DICT_TYPE` environment variable is read, falling back to `FULL_STATE_DICT` for FSDP
              version 1 and `SHARDED_STATE_DICT` otherwise.

      Raises:
          `ValueError`: If the resolved type is `LOCAL_STATE_DICT` while `fsdp_version` is 2.
      """
      from torch.distributed.fsdp.fully_sharded_data_parallel import (
          FullOptimStateDictConfig,
          FullStateDictConfig,
          ShardedOptimStateDictConfig,
          ShardedStateDictConfig,
          StateDictType,
      )

      if state_dict_type is not None:
          # An explicit argument always wins over whatever is stored on the plugin.
          self.state_dict_type = state_dict_type
      elif self.state_dict_type is None:
          version_default = "FULL_STATE_DICT" if self.fsdp_version == 1 else "SHARDED_STATE_DICT"
          self.state_dict_type = os.environ.get("FSDP_STATE_DICT_TYPE", version_default)

      resolved = self.state_dict_type
      if isinstance(resolved, str):
          # Purely numeric strings are looked up by enum value, anything else by (upper-cased) enum name.
          resolved = StateDictType(int(resolved)) if resolved.isdigit() else StateDictType[resolved.upper()]
          self.state_dict_type = resolved

      # Configs that were already set on the plugin are never overwritten.
      if resolved == StateDictType.FULL_STATE_DICT:
          if self.state_dict_config is None:
              self.state_dict_config = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
          if self.optim_state_dict_config is None:
              self.optim_state_dict_config = FullOptimStateDictConfig(offload_to_cpu=True, rank0_only=True)
      elif resolved == StateDictType.SHARDED_STATE_DICT:
          if self.state_dict_config is None:
              self.state_dict_config = ShardedStateDictConfig(offload_to_cpu=True)
          if self.optim_state_dict_config is None:
              self.optim_state_dict_config = ShardedOptimStateDictConfig(offload_to_cpu=True)

      uses_local_state_dict = self.state_dict_type == StateDictType.LOCAL_STATE_DICT
      if self.fsdp_version == 2 and uses_local_state_dict:
          raise ValueError(
              "FSDP2 does not support LOCAL_STATE_DICT. "
              "Please set `fsdp_state_dict_type` to `SHARDED_STATE_DICT` or `FULL_STATE_DICT`."
          )
  def set_auto_wrap_policy(self, model):
      """
      Turn the policy function stored in `self.auto_wrap_policy` into a ready-to-use callable for `model`.

      For the transformer policy, the layer classes are looked up in `model` by name (taken from
      `transformer_cls_names_to_wrap`, or from `model._no_split_modules` when that is unset). For the size based
      policy, `min_num_params` is bound, and a non-positive value disables wrapping by storing `None`. Any other
      policy value is left untouched.

      Raises:
          `ValueError`: If one of the requested transformer layer class names cannot be found in `model`.
      """
      from torch.distributed.fsdp.wrap import size_based_auto_wrap_policy, transformer_auto_wrap_policy

      # The model's own `_no_split_modules` acts as the fallback list of layer names.
      declared_no_split = getattr(model, "_no_split_modules", None)
      fallback_names = [] if declared_no_split is None else list(declared_no_split)

      policy_fn = self.auto_wrap_policy
      if policy_fn == transformer_auto_wrap_policy:
          if self.transformer_cls_names_to_wrap is None:
              self.transformer_cls_names_to_wrap = fallback_names

          layer_classes = set()
          for cls_name in self.transformer_cls_names_to_wrap:
              found_cls = get_module_class_from_name(model, cls_name)
              if found_cls is None:
                  raise ValueError(f"Could not find the transformer layer class {cls_name} in the model.")
              layer_classes.add(found_cls)

          # Bind the resolved classes so the stored policy is directly callable.
          self.auto_wrap_policy = functools.partial(policy_fn, transformer_layer_cls=layer_classes)
      elif policy_fn == size_based_auto_wrap_policy:
          # A threshold of zero (or less) is silently treated as "no auto wrapping".
          self.auto_wrap_policy = (
              functools.partial(policy_fn, min_num_params=self.min_num_params) if self.min_num_params > 0 else None
          )
  def set_mixed_precision(self, mixed_precision, buffer_autocast=False, override=False):
      """
      Sets the mixed precision policy for FSDP.

      Args:
          mixed_precision (`str`, `torch.dtype` or `dict`):
              Either one of `"fp8"`, `"fp16"`, `"bf16"`, `"fp32"` (note that `"fp8"` maps to `torch.bfloat16`), or
              one of the corresponding `torch.dtype` values. Any other object (such as the policy `dict` itself) is
              accepted without validation here.
          buffer_autocast (`bool`, defaults to `False`):
              If `True`, the FSDP1 `buffer_dtype` is `torch.float32` instead of the resolved dtype. Only used when a
              new FSDP1 policy is built.
          override (`bool`, defaults to `False`):
              If `True`, a new policy is built from `mixed_precision` even when `self.mixed_precision_policy` is
              already set.

      Raises:
          `ValueError`: If `mixed_precision` is an unknown string or unsupported `torch.dtype`, or if the stored
          policy `dict` is missing required keys or contains values that are not supported dtypes.
      """
      mixed_precision_mapping = {
          "fp8": torch.bfloat16,
          "fp16": torch.float16,
          "bf16": torch.bfloat16,
          "fp32": torch.float32,
      }
      dtype = mixed_precision
      if isinstance(mixed_precision, str):
          dtype = mixed_precision_mapping.get(mixed_precision, None)
          if dtype is None:
              raise ValueError(
                  f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.keys())}"
              )
      elif isinstance(mixed_precision, torch.dtype) and mixed_precision not in mixed_precision_mapping.values():
          raise ValueError(
              f"Invalid mixed precision: {mixed_precision}. Must be one of {list(mixed_precision_mapping.values())}"
          )

      buffer_type = torch.float32 if buffer_autocast else dtype

      # FSDP1 and FSDP2 expose differently named policy classes; alias them to one local name.
      if self.fsdp_version == 1:
          from torch.distributed.fsdp import MixedPrecision
      elif self.fsdp_version == 2:
          from torch.distributed.fsdp import MixedPrecisionPolicy as MixedPrecision

      if override or self.mixed_precision_policy is None:
          dtype_args = {"param_dtype": dtype, "reduce_dtype": dtype}
          if self.fsdp_version == 1:
              dtype_args["buffer_dtype"] = buffer_type
          else:
              dtype_args["output_dtype"] = dtype
          # TODO(s1ro1): `cast_forward_inputs` for FSDP2?
          self.mixed_precision_policy = MixedPrecision(**dtype_args)
      elif isinstance(self.mixed_precision_policy, dict):
          # The stored dict (not the `mixed_precision` argument) is what gets validated and converted.
          valid_keys = ["param_dtype", "reduce_dtype"] + (
              ["buffer_dtype"] if self.fsdp_version == 1 else ["output_dtype"]
          )
          missing_keys = [k for k in valid_keys if k not in self.mixed_precision_policy]
          invalid_values = [
              k for k, v in self.mixed_precision_policy.items() if v not in mixed_precision_mapping.values()
          ]
          if missing_keys or invalid_values:
              # Each fragment ends with a space so the concatenated sentences do not run together.
              raise ValueError(
                  f"Invalid mixed precision policy: {self.mixed_precision_policy}. "
                  f"Must be a `dict` with keys {valid_keys}. "
                  f"Values must be one of {list(mixed_precision_mapping.values())}"
              )
          self.mixed_precision_policy = MixedPrecision(**self.mixed_precision_policy)
  def validate_mixed_precision_policy(self):
      """
      Check that `self.mixed_precision_policy` is an instance of the policy class matching `self.fsdp_version`.

      Kept as a separate method so the `torch.distributed.fsdp` imports only happen when validation is needed.

      Raises:
          `ValueError`: If the stored policy is not a `MixedPrecisionPolicy` (FSDP2) or `MixedPrecision` (any other
          version).
      """
      targets_fsdp2 = self.fsdp_version == 2
      if targets_fsdp2:
          from torch.distributed.fsdp import MixedPrecisionPolicy as expected_policy_cls
      else:
          from torch.distributed.fsdp import MixedPrecision as expected_policy_cls

      if isinstance(self.mixed_precision_policy, expected_policy_cls):
          return

      if targets_fsdp2:
          required_type = "`torch.distributed.fsdp.MixedPrecisionPolicy`"
      else:
          required_type = "`torch.distributed.fsdp.MixedPrecision`"
      raise ValueError(f"mixed_precision_policy must be an instance of {required_type}.")
  def set_cpu_offload(self):
      """
      Convert a boolean `self.cpu_offload` into the offload object expected by the configured FSDP version.

      FSDP2: `True` becomes `CPUOffloadPolicy()` and `False` becomes `OffloadPolicy()`. Any other version: the flag
      is wrapped as `CPUOffload(offload_params=<flag>)`. Non-boolean values are left as they are.
      """
      # The version-specific classes are imported up front, whether or not a conversion is needed.
      if self.fsdp_version == 2:
          from torch.distributed.fsdp import CPUOffloadPolicy, OffloadPolicy
      else:
          from torch.distributed.fsdp import CPUOffload

      offload_flag = self.cpu_offload
      if not isinstance(offload_flag, bool):
          return

      if self.fsdp_version != 2:
          self.cpu_offload = CPUOffload(offload_params=offload_flag)
      elif offload_flag:
          self.cpu_offload = CPUOffloadPolicy()
      else:
          self.cpu_offload = OffloadPolicy()
  def validate_cpu_offload(self):
      """
      Check that `self.cpu_offload` has the type required by the configured FSDP version.

      Raises:
          `ValueError`: If `fsdp_version` is 2 and `cpu_offload` is not an `OffloadPolicy`, or if `fsdp_version` is
          1 and `cpu_offload` is not a `CPUOffload`.
      """
      version = self.fsdp_version
      if version == 2:
          from torch.distributed.fsdp import OffloadPolicy

          if not isinstance(self.cpu_offload, OffloadPolicy):
              raise ValueError(
                  f"`cpu_offload` must be an instance of `torch.distributed.fsdp.OffloadPolicy` in FSDP2, got {self.cpu_offload}"
              )
          return

      from torch.distributed.fsdp import CPUOffload

      # Only version 1 is validated here; other non-2 values pass through unchecked.
      if version == 1 and not isinstance(self.cpu_offload, CPUOffload):
          raise ValueError(
              f"`cpu_offload` must be an instance of `torch.distributed.fsdp.CPUOffload` in FSDP1, got {self.cpu_offload}"
          )
  @dataclass
  class TorchTensorParallelPlugin:
      """
      Plugin that carries the settings for enabling tensor parallelism with PyTorch >= 2.0.
      """

      # Tensor parallel size; per the help text it is used when the device mesh is prepared.
      tp_size: int = field(
          metadata={"help": "tensor parallel size will be used in the device mesh preparation"},
          default=1,
      )

      # Holds a "torch.distributed.DeviceMesh" once one is available; unset by default.
      torch_device_mesh: Optional["torch.distributed.DeviceMesh"] = None
  @dataclass
  class TorchContextParallelConfig:
      """
      Configuration container for context parallelism in PyTorch.

      When `cp_comm_strategy` is not passed, it is read from the `PARALLELISM_CONFIG_CP_COMM_STRATEGY` environment
      variable and falls back to `"allgather"`.
      """

      cp_comm_strategy: Optional[str] = field(
          metadata={
              "help": "Communication strategy for context parallelism. Can be one of 'allgather' or 'alltoall'. Defaults to 'allgather'."
          },
          default=None,
      )

      def __post_init__(self):
          # Refuse to build the config on a PyTorch older than the minimum supported version.
          torch_supports_cp = is_torch_version(">=", BETA_CP_AVAILABLE_PYTORCH_VERSION)
          if not torch_supports_cp:
              raise ValueError(
                  f"FSDP2-based Context parallelism is only available in PyTorch {BETA_CP_AVAILABLE_PYTORCH_VERSION} and later versions. "
                  "Please upgrade your PyTorch version."
              )

          strategy = self.cp_comm_strategy
          if strategy is None:
              strategy = os.environ.get("PARALLELISM_CONFIG_CP_COMM_STRATEGY", "allgather")
              self.cp_comm_strategy = strategy

          if strategy not in ("allgather", "alltoall"):
              raise ValueError(
                  f"Invalid cp_comm_strategy: {strategy}. Must be one of 'allgather' or 'alltoall'."
              )
  @dataclass
  class DeepSpeedSequenceParallelConfig:
      """
      Settings for DeepSpeed sequence parallelism.

      Every field left as `None` is resolved in `__post_init__` from a `PARALLELISM_CONFIG_SP_*` environment
      variable: `..._SEQ_LENGTH_IS_VARIABLE` (treated as true unless set to something other than "true"),
      `..._SEQ_LENGTH` (only consulted when the length is not variable) and `..._ATTN_IMPLEMENTATION`.
      """

      sp_seq_length: Optional[int] = field(
          metadata={
              "help": "Sequence length for when batches are all of the same length. For variable sequence lengths across batches set `sp_seq_length_is_variable=True` and leave this field unset"
          },
          default=None,
      )
      sp_seq_length_is_variable: Optional[bool] = field(
          metadata={
              "help": "If `True` will work with a sequence length that may change between batches, in which case `sp_seq_length` value can be set to anything divisible by cp size or remain unset. If `False` then `sp_seq_length` needs to match the batch's sequence length dimension. The default is `True`."
          },
          default=None,
      )
      sp_attn_implementation: Optional[str] = field(
          metadata={
              "help": "Attention implementation to use. Can be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'. Defaults to `sdpa`."
          },
          default=None,
      )

      def __post_init__(self):
          env = os.environ

          # `sp_seq_length_is_variable` has to be settled first: it decides whether `sp_seq_length` is mandatory.
          if self.sp_seq_length_is_variable is None:
              raw_flag = env.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH_IS_VARIABLE", "true")
              self.sp_seq_length_is_variable = raw_flag.lower() == "true"

          fixed_length_missing = not self.sp_seq_length_is_variable and self.sp_seq_length is None
          if fixed_length_missing:
              if "PARALLELISM_CONFIG_SP_SEQ_LENGTH" not in env:
                  raise ValueError(
                      "when `sp_seq_length_is_variable` is `False` `sp_seq_length` must be provided either through the constructor or the environment variable PARALLELISM_CONFIG_SP_SEQ_LENGTH"
                  )
              raw_length = env.get("PARALLELISM_CONFIG_SP_SEQ_LENGTH")
              # The literal string "None" is accepted and leaves the length unset.
              self.sp_seq_length = None if raw_length == "None" else int(raw_length)

          if self.sp_attn_implementation is None:
              self.sp_attn_implementation = env.get("PARALLELISM_CONFIG_SP_ATTN_IMPLEMENTATION", None)

          attn_choice = self.sp_attn_implementation
          supported_attn = ("flash_attention_2", "flash_attention_3", "sdpa")
          if attn_choice is not None and attn_choice not in supported_attn:
              raise ValueError(
                  f"Invalid sp_attn_implementation: {attn_choice}. Must be one of 'flash_attention_2', 'flash_attention_3' or 'sdpa'."
              )
  @dataclass
  class TorchTensorParallelConfig:
      """
      Pass this object to your [`Accelerator`] to customize torch tensor parallelism.

      Construction fails when the installed PyTorch or transformers version is older than the minimum this file
      requires for tensor parallelism.
      """

      enable_async_tp: bool = False

      def __post_init__(self):
          torch_is_new_enough = is_torch_version(">=", BETA_TP_AVAILABLE_PYTORCH_VERSION)
          if not torch_is_new_enough:
              raise ValueError(
                  f"Torch tensor parallelism is only available in PyTorch {BETA_TP_AVAILABLE_PYTORCH_VERSION} and later versions. "
                  "Please upgrade your PyTorch version."
              )

          transformers_is_new_enough = compare_versions("transformers", ">=", BETA_TP_AVAILABLE_TRANSFORMERS_VERSION)
          if not transformers_is_new_enough:
              raise ValueError(f"TP requires transformers >= {BETA_TP_AVAILABLE_TRANSFORMERS_VERSION}")

          # The flag is accepted but only triggers a warning; nothing else here acts on it.
          if self.enable_async_tp:
              warnings.warn("Async tensor parallelism is currently not supported, ignoring this option.")
  1957. @dataclass
  1958. class MegatronLMPlugin:
  1959. """
  1960. Plugin for Megatron-LM to enable tensor, pipeline, sequence and data parallelism. Also to enable selective
  1961. activation recomputation and optimized fused kernels.
  1962. Args:
  1963. tp_degree (`int`, defaults to `None`):
  1964. Tensor parallelism degree.
  1965. pp_degree (`int`, defaults to `None`):
  1966. Pipeline parallelism degree.
  1967. num_micro_batches (`int`, defaults to `None`):
  1968. Number of micro-batches.
  1969. gradient_clipping (`float`, defaults to `None`):
  1970. Gradient clipping value based on global L2 Norm (0 to disable).
  1971. sequence_parallelism (`bool`, defaults to `None`):
  1972. Enable sequence parallelism.
  1973. recompute_activations (`bool`, defaults to `None`):
  1974. Enable selective activation recomputation.
  use_distributed_optimizer (`bool`, defaults to `None`):
  1976. Enable distributed optimizer.
  1977. pipeline_model_parallel_split_rank (`int`, defaults to `None`):
  1978. Rank where encoder and decoder should be split.
  1979. num_layers_per_virtual_pipeline_stage (`int`, defaults to `None`):
  1980. Number of layers per virtual pipeline stage.
  1981. is_train_batch_min (`str`, defaults to `True`):
  If both train & eval dataloaders are specified, this will decide the `micro_batch_size`.
  1983. train_iters (`int`, defaults to `None`):
  Total number of iterations to train over all training runs. Note that either train-iters or train-samples
  1985. should be provided when using `MegatronLMDummyScheduler`.
  1986. train_samples (`int`, defaults to `None`):
  1987. Total number of samples to train over all training runs. Note that either train-iters or train-samples
  1988. should be provided when using `MegatronLMDummyScheduler`.
  1989. weight_decay_incr_style (`str`, defaults to `'constant'`):
  1990. Weight decay increment function. choices=["constant", "linear", "cosine"].
  1991. start_weight_decay (`float`, defaults to `None`):
  1992. Initial weight decay coefficient for L2 regularization.
  1993. end_weight_decay (`float`, defaults to `None`):
  1994. End of run weight decay coefficient for L2 regularization.
  1995. lr_decay_style (`str`, defaults to `'linear'`):
  1996. Learning rate decay function. choices=['constant', 'linear', 'cosine'].
  1997. lr_decay_iters (`int`, defaults to `None`):
  1998. Number of iterations for learning rate decay. If None defaults to `train_iters`.
  1999. lr_decay_samples (`int`, defaults to `None`):
  2000. Number of samples for learning rate decay. If None defaults to `train_samples`.
  2001. lr_warmup_iters (`int`, defaults to `None`):
  2002. Number of iterations to linearly warmup learning rate over.
  2003. lr_warmup_samples (`int`, defaults to `None`):
  2004. Number of samples to linearly warmup learning rate over.
  2005. lr_warmup_fraction (`float`, defaults to `None`):
  2006. Fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over.
  2007. min_lr (`float`, defaults to `0`):
  2008. Minimum value for learning rate. The scheduler clip values below this threshold.
  2009. consumed_samples (`List`, defaults to `None`):
  2010. Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call.
  2011. no_wd_decay_cond (`Optional`, defaults to `None`):
  2012. Condition to disable weight decay.
  2013. scale_lr_cond (`Optional`, defaults to `None`):
  2014. Condition to scale learning rate.
  2015. lr_mult (`float`, defaults to `1.0`):
  2016. Learning rate multiplier.
  2017. megatron_dataset_flag (`bool`, defaults to `False`):
  2018. Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format.
  2019. seq_length (`int`, defaults to `None`):
  2020. Maximum sequence length to process.
  2021. encoder_seq_length (`int`, defaults to `None`):
  2022. Maximum sequence length to process for the encoder.
  2023. decoder_seq_length (`int`, defaults to `None`):
  2024. Maximum sequence length to process for the decoder.
  2025. tensorboard_dir (`str`, defaults to `None`):
  2026. Path to save tensorboard logs.
  2027. set_all_logging_options (`bool`, defaults to `False`):
  2028. Whether to set all logging options.
  2029. eval_iters (`int`, defaults to `100`):
  2030. Number of iterations to run for evaluation validation/test for.
  2031. eval_interval (`int`, defaults to `1000`):
  2032. Interval between running evaluation on validation set.
  2033. return_logits (`bool`, defaults to `False`):
  2034. Whether to return logits from the model.
  2035. custom_train_step_class (`Optional`, defaults to `None`):
  2036. Custom train step class.
  2037. custom_train_step_kwargs (`Optional`, defaults to `None`):
  2038. Custom train step kwargs.
  2039. custom_model_provider_function (`Optional`, defaults to `None`):
  2040. Custom model provider function.
  2041. custom_prepare_model_function (`Optional`, defaults to `None`):
  2042. Custom prepare model function.
  2043. custom_megatron_datasets_provider_function (`Optional`, defaults to `None`):
  2044. Custom megatron train_valid_test datasets provider function.
  2045. custom_get_batch_function (`Optional`, defaults to `None`):
  2046. Custom get batch function.
  2047. custom_loss_function (`Optional`, defaults to `None`):
  2048. Custom loss function.
  2049. other_megatron_args (`Optional`, defaults to `None`):
  2050. Other Megatron-LM arguments. Please refer Megatron-LM.
  2051. """
  2052. tp_degree: int = field(default=None, metadata={"help": "tensor parallelism degree."})
  2053. pp_degree: int = field(default=None, metadata={"help": "pipeline parallelism degree."})
  2054. num_micro_batches: int = field(default=None, metadata={"help": "number of micro-batches."})
  2055. gradient_clipping: float = field(
  2056. default=None,
  2057. metadata={"help": "gradient clipping value based on global L2 Norm (0 to disable)"},
  2058. )
  2059. sequence_parallelism: bool = field(
  2060. default=None,
  2061. metadata={"help": "enable sequence parallelism"},
  2062. )
  2063. recompute_activations: bool = field(
  2064. default=None,
  2065. metadata={"help": "enable selective activation recomputation"},
  2066. )
  2067. use_distributed_optimizer: bool = field(
  2068. default=None,
  2069. metadata={"help": "enable distributed optimizer"},
  2070. )
  2071. pipeline_model_parallel_split_rank: int = field(
  2072. default=None,
  2073. metadata={"help": "Rank where encoder and decoder should be split."},
  2074. )
  2075. num_layers_per_virtual_pipeline_stage: int = field(
  2076. default=None, metadata={"help": "Number of layers per virtual pipeline stage."}
  2077. )
  2078. is_train_batch_min: str = field(
  2079. default=True,
  2080. metadata={"help": "If both train & eval dataloaders are specified, this will decide the micro_batch_size"},
  2081. )
  2082. train_iters: int = field(
  2083. default=None,
  2084. metadata={
  2085. "help": "Total number of iterations to train over all training runs. "
  2086. "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`"
  2087. },
  2088. )
  2089. train_samples: int = field(
  2090. default=None,
  2091. metadata={
  2092. "help": "Total number of samples to train over all training runs. "
  2093. "Note that either train-iters or train-samples should be provided when using `MegatronLMDummyScheduler`"
  2094. },
  2095. )
  2096. weight_decay_incr_style: str = field(
  2097. default="constant",
  2098. metadata={"help": 'Weight decay increment function. choices=["constant", "linear", "cosine"]. '},
  2099. )
  2100. start_weight_decay: float = field(
  2101. default=None,
  2102. metadata={"help": "Initial weight decay coefficient for L2 regularization."},
  2103. )
  2104. end_weight_decay: float = field(
  2105. default=None,
  2106. metadata={"help": "End of run weight decay coefficient for L2 regularization."},
  2107. )
  2108. lr_decay_style: str = field(
  2109. default="linear",
  2110. metadata={"help": "Learning rate decay function. choices=['constant', 'linear', 'cosine']."},
  2111. )
  2112. lr_decay_iters: int = field(
  2113. default=None,
  2114. metadata={"help": "Number of iterations for learning rate decay. If None defaults to `train_iters`."},
  2115. )
  2116. lr_decay_samples: int = field(
  2117. default=None,
  2118. metadata={"help": "Number of samples for learning rate decay. If None defaults to `train_samples`."},
  2119. )
  2120. lr_warmup_iters: int = field(
  2121. default=None,
  2122. metadata={"help": "number of iterations to linearly warmup learning rate over."},
  2123. )
  2124. lr_warmup_samples: int = field(
  2125. default=None,
  2126. metadata={"help": "number of samples to linearly warmup learning rate over."},
  2127. )
  2128. lr_warmup_fraction: float = field(
  2129. default=None,
  2130. metadata={"help": "fraction of lr-warmup-(iters/samples) to linearly warmup learning rate over."},
  2131. )
  2132. min_lr: float = field(
  2133. default=0,
  2134. metadata={"help": "Minimum value for learning rate. The scheduler clip values below this threshold."},
  2135. )
  2136. consumed_samples: list[int] = field(
  2137. default=None,
  2138. metadata={
  2139. "help": "Number of samples consumed in the same order as the dataloaders to `accelerator.prepare` call."
  2140. },
  2141. )
  2142. no_wd_decay_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to disable weight decay."})
  2143. scale_lr_cond: Optional[Callable] = field(default=None, metadata={"help": "Condition to scale learning rate."})
  2144. lr_mult: float = field(default=1.0, metadata={"help": "Learning rate multiplier."})
  2145. megatron_dataset_flag: bool = field(
  2146. default=False,
  2147. metadata={"help": "Whether the format of dataset follows Megatron-LM Indexed/Cached/MemoryMapped format."},
  2148. )
  2149. seq_length: int = field(
  2150. default=None,
  2151. metadata={"help": "Maximum sequence length to process."},
  2152. )
  2153. encoder_seq_length: int = field(
  2154. default=None,
  2155. metadata={"help": "Maximum sequence length to process for the encoder."},
  2156. )
  2157. decoder_seq_length: int = field(
  2158. default=None,
  2159. metadata={"help": "Maximum sequence length to process for the decoder."},
  2160. )
  2161. tensorboard_dir: str = field(
  2162. default=None,
  2163. metadata={"help": "Path to save tensorboard logs."},
  2164. )
  2165. set_all_logging_options: bool = field(
  2166. default=False,
  2167. metadata={"help": "Whether to set all logging options."},
  2168. )
  2169. eval_iters: int = field(
  2170. default=100,
  2171. metadata={"help": "Number of iterations to run for evaluation validation/test for."},
  2172. )
  2173. eval_interval: int = field(
  2174. default=1000,
  2175. metadata={"help": "Interval between running evaluation on validation set."},
  2176. )
  2177. return_logits: bool = field(
  2178. default=False,
  2179. metadata={"help": "Whether to return logits from the model."},
  2180. )
  2181. # custom train step args
  2182. custom_train_step_class: Optional[Any] = field(
  2183. default=None,
  2184. metadata={"help": "Custom train step class."},
  2185. )
  2186. custom_train_step_kwargs: Optional[dict[str, Any]] = field(
  2187. default=None,
  2188. metadata={"help": "Custom train step kwargs."},
  2189. )
  2190. # custom model args
  2191. custom_model_provider_function: Optional[Callable] = field(
  2192. default=None,
  2193. metadata={"help": "Custom model provider function."},
  2194. )
  2195. custom_prepare_model_function: Optional[Callable] = field(
  2196. default=None,
  2197. metadata={"help": "Custom prepare model function."},
  2198. )
  2199. custom_megatron_datasets_provider_function: Optional[Callable] = field(
  2200. default=None,
  2201. metadata={"help": "Custom megatron train_valid_test datasets provider function."},
  2202. )
  2203. custom_get_batch_function: Optional[Callable] = field(
  2204. default=None,
  2205. metadata={"help": "Custom get batch function."},
  2206. )
  2207. custom_loss_function: Optional[Callable] = field(
  2208. default=None,
  2209. metadata={"help": "Custom loss function."},
  2210. )
  2211. # remaining args such as enabling Alibi/ROPE positional embeddings,
  2212. # wandb logging, Multi-Query Attention, etc.
  2213. other_megatron_args: Optional[dict[str, Any]] = field(
  2214. default=None,
  2215. metadata={"help": "Other Megatron-LM arguments. Please refer Megatron-LM"},
  2216. )
  2217. def __post_init__(self):
  2218. prefix = "MEGATRON_LM_"
  2219. if self.tp_degree is None:
  2220. self.tp_degree = int(os.environ.get(prefix + "TP_DEGREE", 1))
  2221. if self.pp_degree is None:
  2222. self.pp_degree = int(os.environ.get(prefix + "PP_DEGREE", 1))
  2223. if self.num_micro_batches is None:
  2224. self.num_micro_batches = int(os.environ.get(prefix + "NUM_MICRO_BATCHES", 1))
  2225. if self.gradient_clipping is None:
  2226. self.gradient_clipping = float(os.environ.get(prefix + "GRADIENT_CLIPPING", 1.0))
  2227. if self.recompute_activations is None:
  2228. self.recompute_activations = str_to_bool(os.environ.get(prefix + "RECOMPUTE_ACTIVATIONS", "False")) == 1
  2229. if self.use_distributed_optimizer is None:
  2230. self.use_distributed_optimizer = (
  2231. str_to_bool(os.environ.get(prefix + "USE_DISTRIBUTED_OPTIMIZER", "False")) == 1
  2232. )
  2233. if self.sequence_parallelism is None:
  2234. self.sequence_parallelism = str_to_bool(os.environ.get(prefix + "SEQUENCE_PARALLELISM", "False")) == 1
  2235. if self.pp_degree > 1 or self.use_distributed_optimizer:
  2236. self.DDP_impl = "local"
  2237. else:
  2238. self.DDP_impl = "torch"
  2239. if self.consumed_samples is not None:
  2240. if len(self.consumed_samples) == 1:
  2241. self.consumed_samples.extend([0, 0])
  2242. elif len(self.consumed_samples) == 2:
  2243. self.consumed_samples.append(0)
  2244. self.megatron_lm_default_args = {
  2245. "tensor_model_parallel_size": self.tp_degree,
  2246. "pipeline_model_parallel_size": self.pp_degree,
  2247. "pipeline_model_parallel_split_rank": self.pipeline_model_parallel_split_rank,
  2248. "num_layers_per_virtual_pipeline_stage": self.num_layers_per_virtual_pipeline_stage,
  2249. "DDP_impl": self.DDP_impl,
  2250. "use_distributed_optimizer": self.use_distributed_optimizer,
  2251. "sequence_parallel": self.sequence_parallelism,
  2252. "clip_grad": self.gradient_clipping,
  2253. "num_micro_batches": self.num_micro_batches,
  2254. "consumed_samples": self.consumed_samples,
  2255. "no_wd_decay_cond": self.no_wd_decay_cond,
  2256. "scale_lr_cond": self.scale_lr_cond,
  2257. "lr_mult": self.lr_mult,
  2258. "megatron_dataset_flag": self.megatron_dataset_flag,
  2259. "eval_iters": self.eval_iters,
  2260. "eval_interval": self.eval_interval,
  2261. }
  2262. if self.recompute_activations:
  2263. self.megatron_lm_default_args["recompute_granularity"] = "selective"
  2264. if self.tensorboard_dir is not None:
  2265. self.megatron_lm_default_args["tensorboard_dir"] = self.tensorboard_dir
  2266. if self.set_all_logging_options:
  2267. self.set_tensorboard_logging_options()
  2268. if self.other_megatron_args is not None:
  2269. self.megatron_lm_default_args.update(self.other_megatron_args)
  def set_network_size_args(self, model, batch_data=None):
      """
      Run the first registered Megatron config parser whose key occurs in the model's `config.model_type`.

      Matching is a case-insensitive substring test on the model type, in the iteration order of
      `MODEL_CONFIGS_TO_MEGATRON_PARSERS`. The selected parser is called with `(self, model, batch_data)`.

      Raises:
          `ValueError`: If no registered key is contained in the model type.
      """
      model_config_type = model.config.model_type.lower()
      matching_parsers = (
          parser
          for registered_type, parser in MODEL_CONFIGS_TO_MEGATRON_PARSERS.items()
          if registered_type in model_config_type
      )
      selected_parser = next(matching_parsers, None)
      if selected_parser is None:
          raise ValueError(
              f"Accelerate Megatron-LM integration not supports {model_config_type} model. "
              "You can add your own model config parser."
          )
      selected_parser(self, model, batch_data)
  def set_mixed_precision(self, mixed_precision):
      """
      Record the requested mixed precision mode in `megatron_lm_default_args`.

      `"fp16"` sets the `fp16` flag. `"bf16"` sets the `bf16` flag and additionally switches `DDP_impl` to
      `"local"` on both the plugin and the arguments dict. Any other value leaves everything untouched.
      """
      if mixed_precision == "fp16":
          self.megatron_lm_default_args["fp16"] = True
          return
      if mixed_precision != "bf16":
          return

      self.megatron_lm_default_args["bf16"] = True
      # bf16 switches the DDP implementation to "local" on the plugin and in the arguments.
      self.DDP_impl = "local"
      self.megatron_lm_default_args["DDP_impl"] = self.DDP_impl
  def set_training_args(self, micro_batch_size, dp_degree):
      """
      Store the batch geometry on the plugin and mirror it into `megatron_lm_default_args`.

      The global batch size is computed as `dp_degree * micro_batch_size * self.num_micro_batches`.
      """
      self.data_parallel_size = dp_degree
      self.micro_batch_size = micro_batch_size
      self.global_batch_size = dp_degree * micro_batch_size * self.num_micro_batches

      for arg_name in ("data_parallel_size", "micro_batch_size", "global_batch_size"):
          self.megatron_lm_default_args[arg_name] = getattr(self, arg_name)
  2294. def set_optimizer_type(self, optimizer):
  2295. optimizer_name = optimizer.__class__.__name__.lower()
  2296. if "adam" in optimizer_name:
  2297. self.megatron_lm_default_args["optimizer"] = "adam"
  2298. self.megatron_lm_default_args["adam_beta1"] = optimizer.defaults["betas"][0]
  2299. self.megatron_lm_default_args["adam_beta2"] = optimizer.defaults["betas"][1]
  2300. self.megatron_lm_default_args["adam_eps"] = optimizer.defaults["eps"]
  2301. elif "sgd" in optimizer_name:
  2302. self.megatron_lm_default_args["optimizer"] = "sgd"
  2303. self.megatron_lm_default_args["sgd_momentum"] = optimizer.defaults["momentum"]
  2304. else:
  2305. raise ValueError(f"Optimizer {optimizer_name} is not supported by Megatron-LM")
  2306. self.megatron_lm_default_args["lr"] = optimizer.defaults["lr"]
  2307. self.megatron_lm_default_args["weight_decay"] = optimizer.defaults["weight_decay"]
  2308. def set_scheduler_args(self, scheduler):
  2309. if self.train_iters is None:
  2310. self.train_iters = scheduler.total_num_steps // self.megatron_lm_default_args["data_parallel_size"]
  2311. if self.train_samples is not None:
  2312. self.train_samples = None
  2313. warnings.warn(
  2314. "Ignoring `train_samples` as `train_iters` based on scheduler is being used for training."
  2315. )
  2316. if self.lr_warmup_iters is None:
  2317. self.lr_warmup_iters = scheduler.warmup_num_steps // self.megatron_lm_default_args["data_parallel_size"]
  2318. if self.lr_warmup_samples is not None:
  2319. warnings.warn(
  2320. "Ignoring `lr_warmup_samples` as `lr_warmup_iters` based on scheduler is being used for training."
  2321. )
  2322. self.lr_warmup_samples = 0
  2323. self.megatron_lm_default_args["train_iters"] = self.train_iters
  2324. self.megatron_lm_default_args["lr_warmup_iters"] = self.lr_warmup_iters
  2325. self.megatron_lm_default_args["train_samples"] = self.train_samples
  2326. self.megatron_lm_default_args["lr_warmup_samples"] = self.lr_warmup_samples
  2327. self.megatron_lm_default_args["lr_decay_iters"] = self.lr_decay_iters
  2328. self.megatron_lm_default_args["lr_decay_samples"] = self.lr_decay_samples
  2329. self.megatron_lm_default_args["lr_warmup_fraction"] = self.lr_warmup_fraction
  2330. self.megatron_lm_default_args["lr_decay_style"] = self.lr_decay_style
  2331. self.megatron_lm_default_args["weight_decay_incr_style"] = self.weight_decay_incr_style
  2332. self.megatron_lm_default_args["start_weight_decay"] = self.start_weight_decay
  2333. self.megatron_lm_default_args["end_weight_decay"] = self.end_weight_decay
  2334. self.megatron_lm_default_args["min_lr"] = self.min_lr
  2335. def set_tensorboard_logging_options(self):
  2336. from megatron.training.arguments import _add_logging_args
  2337. parser = argparse.ArgumentParser()
  2338. parser = _add_logging_args(parser)
  2339. logging_args = parser.parse_known_args()
  2340. self.dataset_args = vars(logging_args[0])
  2341. for key, value in self.dataset_args.items():
  2342. if key.startswith("log_"):
  2343. self.megatron_lm_default_args[key] = True
  2344. elif key.startswith("no_log_"):
  2345. self.megatron_lm_default_args[key.replace("no_", "")] = True
  2346. MODEL_CONFIGS_TO_MEGATRON_PARSERS = {}
  2347. def add_model_config_to_megatron_parser(model_type: str):
  2348. def add_model_config_parser_helper(func):
  2349. @functools.wraps(func)
  2350. def wrapper(*args, **kwargs):
  2351. return func(*args, **kwargs)
  2352. MODEL_CONFIGS_TO_MEGATRON_PARSERS[model_type] = func
  2353. return wrapper
  2354. return add_model_config_parser_helper
  2355. @add_model_config_to_megatron_parser("megatron-bert")
  2356. def parse_bert_config(megatron_lm_plugin, model, batch_data):
  2357. model_type_name = "bert"
  2358. num_layers = model.config.num_hidden_layers
  2359. hidden_size = model.config.hidden_size
  2360. num_attention_heads = model.config.num_attention_heads
  2361. max_position_embeddings = model.config.max_position_embeddings
  2362. num_labels = model.config.num_labels
  2363. orig_vocab_size = model.config.vocab_size
  2364. pretraining_flag = False
  2365. if "maskedlm" in model.__class__.__name__.lower():
  2366. pretraining_flag = True
  2367. if megatron_lm_plugin.seq_length is not None:
  2368. if megatron_lm_plugin.encoder_seq_length is not None:
  2369. warnings.warn("Both `seq_length` and `encoder_seq_length` are set. Using `encoder_seq_length`.")
  2370. megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length
  2371. elif megatron_lm_plugin.encoder_seq_length is not None:
  2372. megatron_lm_plugin.seq_length = megatron_lm_plugin.encoder_seq_length
  2373. elif batch_data is not None:
  2374. megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1]
  2375. else:
  2376. megatron_lm_plugin.seq_length = max_position_embeddings
  2377. megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length
  2378. megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name
  2379. megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers
  2380. megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size
  2381. megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads
  2382. megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings
  2383. megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag
  2384. megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size
  2385. megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict
  2386. megatron_lm_plugin.megatron_lm_default_args["num_labels"] = num_labels
  2387. @add_model_config_to_megatron_parser("gpt2")
  2388. def parse_gpt2_config(megatron_lm_plugin, model, batch_data):
  2389. model_type_name = "gpt"
  2390. num_layers = model.config.n_layer
  2391. hidden_size = model.config.n_embd
  2392. num_attention_heads = model.config.n_head
  2393. max_position_embeddings = model.config.n_positions
  2394. orig_vocab_size = model.config.vocab_size
  2395. pretraining_flag = True
  2396. if megatron_lm_plugin.seq_length is not None:
  2397. if megatron_lm_plugin.decoder_seq_length is not None:
  2398. warnings.warn("Both `seq_length` and `decoder_seq_length` are set. Using `decoder_seq_length`.")
  2399. megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length
  2400. elif megatron_lm_plugin.decoder_seq_length is not None:
  2401. megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length
  2402. elif batch_data is not None:
  2403. megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1]
  2404. else:
  2405. megatron_lm_plugin.seq_length = max_position_embeddings
  2406. megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length
  2407. megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits
  2408. megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "GPT2BPETokenizer"
  2409. megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name
  2410. megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers
  2411. megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size
  2412. megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads
  2413. megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings
  2414. megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag
  2415. megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size
  2416. megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict
  2417. @add_model_config_to_megatron_parser("t5")
  2418. def parse_t5_config(megatron_lm_plugin, model, batch_data):
  2419. model_type_name = "t5"
  2420. num_layers = model.config.num_layers
  2421. hidden_size = model.config.d_model
  2422. num_attention_heads = model.config.num_heads
  2423. max_position_embeddings = model.config.n_positions if hasattr(model.config, "n_positions") else 1024
  2424. orig_vocab_size = model.config.vocab_size
  2425. pretraining_flag = True
  2426. if megatron_lm_plugin.encoder_seq_length is None:
  2427. if batch_data is not None:
  2428. megatron_lm_plugin.encoder_seq_length = batch_data["input_ids"].shape[1]
  2429. else:
  2430. megatron_lm_plugin.encoder_seq_length = max_position_embeddings
  2431. if megatron_lm_plugin.decoder_seq_length is None:
  2432. if batch_data is not None:
  2433. megatron_lm_plugin.decoder_seq_length = batch_data["labels"].shape[1]
  2434. else:
  2435. megatron_lm_plugin.decoder_seq_length = max_position_embeddings
  2436. megatron_lm_plugin.megatron_lm_default_args["encoder_seq_length"] = megatron_lm_plugin.encoder_seq_length
  2437. megatron_lm_plugin.megatron_lm_default_args["decoder_seq_length"] = megatron_lm_plugin.decoder_seq_length
  2438. megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name
  2439. megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers
  2440. megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size
  2441. megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads
  2442. megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings
  2443. megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag
  2444. megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size
  2445. megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict
  2446. @add_model_config_to_megatron_parser("llama")
  2447. def parse_llama_config(megatron_lm_plugin, model, batch_data):
  2448. model_type_name = "gpt"
  2449. num_layers = model.config.num_hidden_layers
  2450. pretraining_flag = True
  2451. hidden_size = model.config.hidden_size
  2452. num_attention_heads = model.config.num_attention_heads
  2453. orig_vocab_size = model.config.vocab_size
  2454. max_position_embeddings = model.config.max_position_embeddings
  2455. seq_length = getattr(model.config, "max_sequence_length", None)
  2456. if megatron_lm_plugin.seq_length is None:
  2457. if seq_length is not None:
  2458. megatron_lm_plugin.seq_length = seq_length
  2459. elif megatron_lm_plugin.decoder_seq_length is not None:
  2460. megatron_lm_plugin.seq_length = megatron_lm_plugin.decoder_seq_length
  2461. elif batch_data is not None:
  2462. megatron_lm_plugin.seq_length = batch_data["input_ids"].shape[1]
  2463. else:
  2464. megatron_lm_plugin.seq_length = max_position_embeddings
  2465. megatron_lm_plugin.megatron_lm_default_args["return_logits"] = megatron_lm_plugin.return_logits
  2466. megatron_lm_plugin.megatron_lm_default_args["tokenizer_type"] = "Llama2Tokenizer"
  2467. megatron_lm_plugin.megatron_lm_default_args["model_type_name"] = model_type_name
  2468. megatron_lm_plugin.megatron_lm_default_args["num_layers"] = num_layers
  2469. megatron_lm_plugin.megatron_lm_default_args["pretraining_flag"] = pretraining_flag
  2470. megatron_lm_plugin.megatron_lm_default_args["hidden_size"] = hidden_size
  2471. megatron_lm_plugin.megatron_lm_default_args["num_attention_heads"] = num_attention_heads
  2472. megatron_lm_plugin.megatron_lm_default_args["orig_vocab_size"] = orig_vocab_size
  2473. megatron_lm_plugin.megatron_lm_default_args["max_position_embeddings"] = max_position_embeddings
  2474. megatron_lm_plugin.megatron_lm_default_args["seq_length"] = megatron_lm_plugin.seq_length
  2475. megatron_lm_plugin.megatron_lm_default_args["model_return_dict"] = model.config.return_dict
  2476. @dataclass
  2477. class BnbQuantizationConfig:
  2478. """
  2479. A plugin to enable BitsAndBytes 4bit and 8bit quantization
  2480. Args:
  2481. load_in_8bit (`bool`, defaults to `False`):
  2482. Enable 8bit quantization.
  2483. llm_int8_threshold (`float`, defaults to `6.0`):
  2484. Value of the outliner threshold. Only relevant when `load_in_8bit=True`.
  2485. load_in_4bit (`bool`, defaults to `False`):
  2486. Enable 4bit quantization.
  2487. bnb_4bit_quant_type (`str`, defaults to `fp4`):
  2488. Set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','np4'}.
  2489. bnb_4bit_use_double_quant (`bool`, defaults to `False`):
  2490. Enable nested quantization where the quantization constants from the first quantization are quantized
  2491. again.
  2492. bnb_4bit_compute_dtype (`bool`, defaults to `fp16`):
  2493. This sets the computational type which might be different than the input time. For example, inputs might be
  2494. fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}.
  2495. torch_dtype (`torch.dtype`, defaults to `None`):
  2496. This sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value
  2497. to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model.
  2498. skip_modules (`List[str]`, defaults to `None`):
  2499. An explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`.
  2500. keep_in_fp32_modules (`List`, defaults to `None`):
  2501. An explicit list of the modules that we don't quantize. We keep them in `torch.float32`.
  2502. """
  2503. load_in_8bit: bool = field(default=False, metadata={"help": "enable 8bit quantization."})
  2504. llm_int8_threshold: float = field(
  2505. default=6.0,
  2506. metadata={"help": "value of the outliner threshold. only relevant when load_in_8bit=True"},
  2507. )
  2508. load_in_4bit: bool = field(default=False, metadata={"help": "enable 4bit quantization."})
  2509. bnb_4bit_quant_type: str = field(
  2510. default="fp4",
  2511. metadata={
  2512. "help": "set the quantization data type in the `bnb.nn.Linear4Bit` layers. Options are {'fp4','nf4'}."
  2513. },
  2514. )
  2515. bnb_4bit_use_double_quant: bool = field(
  2516. default=False,
  2517. metadata={
  2518. "help": "enable nested quantization where the quantization constants from the first quantization are quantized again."
  2519. },
  2520. )
  2521. bnb_4bit_compute_dtype: str = field(
  2522. default="fp16",
  2523. metadata={
  2524. "help": "This sets the computational type which might be different than the input time. For example, inputs might be "
  2525. "fp32, but computation can be set to bf16 for speedups. Options are {'fp32','fp16','bf16'}."
  2526. },
  2527. )
  2528. torch_dtype: torch.dtype = field(
  2529. default=None,
  2530. metadata={
  2531. "help": "this sets the dtype of the remaining non quantized layers. `bitsandbytes` library suggests to set the value"
  2532. "to `torch.float16` for 8 bit model and use the same dtype as the compute dtype for 4 bit model "
  2533. },
  2534. )
  2535. skip_modules: list[str] = field(
  2536. default=None,
  2537. metadata={
  2538. "help": "an explicit list of the modules that we don't quantize. The dtype of these modules will be `torch_dtype`."
  2539. },
  2540. )
  2541. keep_in_fp32_modules: list[str] = field(
  2542. default=None,
  2543. metadata={"help": "an explicit list of the modules that we don't quantize. We keep them in `torch.float32`."},
  2544. )
  2545. def __post_init__(self):
  2546. """
  2547. Safety checker that arguments are correct - also replaces some NoneType arguments with their default values.
  2548. """
  2549. if not isinstance(self.load_in_8bit, bool):
  2550. raise ValueError("load_in_8bit must be a boolean")
  2551. if not isinstance(self.load_in_4bit, bool):
  2552. raise ValueError("load_in_4bit must be a boolean")
  2553. if self.load_in_4bit and self.load_in_8bit:
  2554. raise ValueError("load_in_4bit and load_in_8bit can't be both True")
  2555. if not self.load_in_4bit and not self.load_in_8bit:
  2556. raise ValueError("load_in_4bit and load_in_8bit can't be both False")
  2557. if not isinstance(self.llm_int8_threshold, (int, float)):
  2558. raise ValueError("llm_int8_threshold must be a float or an int")
  2559. if not isinstance(self.bnb_4bit_quant_type, str):
  2560. raise ValueError("bnb_4bit_quant_type must be a string")
  2561. elif self.bnb_4bit_quant_type not in ["fp4", "nf4"]:
  2562. raise ValueError(f"bnb_4bit_quant_type must be in ['fp4','nf4'] but found {self.bnb_4bit_quant_type}")
  2563. if not isinstance(self.bnb_4bit_use_double_quant, bool):
  2564. raise ValueError("bnb_4bit_use_double_quant must be a boolean")
  2565. if isinstance(self.bnb_4bit_compute_dtype, str):
  2566. if self.bnb_4bit_compute_dtype == "fp32":
  2567. self.bnb_4bit_compute_dtype = torch.float32
  2568. elif self.bnb_4bit_compute_dtype == "fp16":
  2569. self.bnb_4bit_compute_dtype = torch.float16
  2570. elif self.bnb_4bit_compute_dtype == "bf16":
  2571. self.bnb_4bit_compute_dtype = torch.bfloat16
  2572. else:
  2573. raise ValueError(
  2574. f"bnb_4bit_compute_dtype must be in ['fp32','fp16','bf16'] but found {self.bnb_4bit_compute_dtype}"
  2575. )
  2576. elif not isinstance(self.bnb_4bit_compute_dtype, torch.dtype):
  2577. raise ValueError("bnb_4bit_compute_dtype must be a string or a torch.dtype")
  2578. if self.skip_modules is not None and not isinstance(self.skip_modules, list):
  2579. raise ValueError("skip_modules must be a list of strings")
  2580. if self.keep_in_fp32_modules is not None and not isinstance(self.keep_in_fp32_modules, list):
  2581. raise ValueError("keep_in_fp_32_modules must be a list of strings")
  2582. if self.load_in_4bit:
  2583. self.target_dtype = CustomDtype.INT4
  2584. if self.load_in_8bit:
  2585. self.target_dtype = torch.int8
  2586. if self.load_in_4bit and self.llm_int8_threshold != 6.0:
  2587. warnings.warn("llm_int8_threshold can only be used for model loaded in 8bit")
  2588. if isinstance(self.torch_dtype, str):
  2589. if self.torch_dtype == "fp32":
  2590. self.torch_dtype = torch.float32
  2591. elif self.torch_dtype == "fp16":
  2592. self.torch_dtype = torch.float16
  2593. elif self.torch_dtype == "bf16":
  2594. self.torch_dtype = torch.bfloat16
  2595. else:
  2596. raise ValueError(f"torch_dtype must be in ['fp32','fp16','bf16'] but found {self.torch_dtype}")
  2597. if self.load_in_8bit and self.torch_dtype is None:
  2598. self.torch_dtype = torch.float16
  2599. if self.load_in_4bit and self.torch_dtype is None:
  2600. self.torch_dtype = self.bnb_4bit_compute_dtype
  2601. if not isinstance(self.torch_dtype, torch.dtype):
  2602. raise ValueError("torch_dtype must be a torch.dtype")
  2603. def get_module_class_from_name(module, name):
  2604. """
  2605. Gets a class from a module by its name.
  2606. Args:
  2607. module (`torch.nn.Module`): The module to get the class from.
  2608. name (`str`): The name of the class.
  2609. """
  2610. modules_children = list(module.children())
  2611. if module.__class__.__name__ == name:
  2612. return module.__class__
  2613. elif len(modules_children) == 0:
  2614. return
  2615. else:
  2616. for child_module in modules_children:
  2617. module_class = get_module_class_from_name(child_module, name)
  2618. if module_class is not None:
  2619. return module_class