modeling.py 93 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import contextlib
  15. import gc
  16. import inspect
  17. import json
  18. import logging
  19. import os
  20. import re
  21. import shutil
  22. import tempfile
  23. import warnings
  24. from collections import OrderedDict, defaultdict
  25. from typing import Optional, Union
  26. import torch
  27. from torch import distributed as dist
  28. from torch import nn
  29. from ..state import AcceleratorState
  30. from .constants import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
  31. from .dataclasses import AutocastKwargs, CustomDtype, DistributedType
  32. from .imports import (
  33. is_hpu_available,
  34. is_mlu_available,
  35. is_mps_available,
  36. is_musa_available,
  37. is_npu_available,
  38. is_peft_available,
  39. is_sdaa_available,
  40. is_torch_xla_available,
  41. is_xpu_available,
  42. )
  43. from .memory import clear_device_cache, get_xpu_available_memory
  44. from .offload import load_offloaded_weight, offload_weight, save_offload_index
  45. from .tqdm import is_tqdm_available, tqdm
  46. from .versions import is_torch_version
  47. if is_npu_available(check_device=False):
  48. import torch_npu # noqa: F401
  49. if is_mlu_available(check_device=False):
  50. import torch_mlu # noqa: F401
  51. if is_sdaa_available(check_device=False):
  52. import torch_sdaa # noqa: F401
  53. if is_musa_available(check_device=False):
  54. import torch_musa # noqa: F401
  55. from safetensors import safe_open
  56. from safetensors.torch import load_file as safe_load_file
# File name of the JSON index used for sharded `pytorch_model.bin` checkpoints (presumably mapping weight names to
# shard files -- confirm against the loading code).
WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
# Module-level logger shared by the helpers in this file.
logger = logging.getLogger(__name__)
  59. def is_peft_model(model):
  60. from .other import extract_model_from_parallel
  61. if is_peft_available():
  62. from peft import PeftModel
  63. return is_peft_available() and isinstance(extract_model_from_parallel(model), PeftModel)
  64. def check_device_same(first_device, second_device):
  65. """
  66. Utility method to check if two `torch` devices are similar. When dealing with CUDA devices, torch throws `False`
  67. for `torch.device("cuda") == torch.device("cuda:0")` whereas they should be the same
  68. Args:
  69. first_device (`torch.device`):
  70. First device to check
  71. second_device (`torch.device`):
  72. Second device to check
  73. """
  74. if first_device.type != second_device.type:
  75. return False
  76. if first_device.type != "cpu" and first_device.index is None:
  77. # In case the first_device is a cuda device and have
  78. # the index attribute set to `None`, default it to `0`
  79. first_device = torch.device(first_device.type, index=0)
  80. if second_device.type != "cpu" and second_device.index is None:
  81. # In case the second_device is a cuda device and have
  82. # the index attribute set to `None`, default it to `0`
  83. second_device = torch.device(second_device.type, index=0)
  84. return first_device == second_device
  85. def convert_file_size_to_int(size: Union[int, str]):
  86. """
  87. Converts a size expressed as a string with digits an unit (like `"5MB"`) to an integer (in bytes).
  88. Args:
  89. size (`int` or `str`): The size to convert. Will be directly returned if an `int`.
  90. Example:
  91. ```py
  92. >>> convert_file_size_to_int("1MiB")
  93. 1048576
  94. ```
  95. """
  96. mem_size = -1
  97. err_msg = (
  98. f"`size` {size} is not in a valid format. Use an integer for bytes, or a string with an unit (like '5.0GB')."
  99. )
  100. try:
  101. if isinstance(size, int):
  102. mem_size = size
  103. elif size.upper().endswith("GIB"):
  104. mem_size = int(float(size[:-3]) * (2**30))
  105. elif size.upper().endswith("MIB"):
  106. mem_size = int(float(size[:-3]) * (2**20))
  107. elif size.upper().endswith("KIB"):
  108. mem_size = int(float(size[:-3]) * (2**10))
  109. elif size.upper().endswith("GB"):
  110. int_size = int(float(size[:-2]) * (10**9))
  111. mem_size = int_size // 8 if size.endswith("b") else int_size
  112. elif size.upper().endswith("MB"):
  113. int_size = int(float(size[:-2]) * (10**6))
  114. mem_size = int_size // 8 if size.endswith("b") else int_size
  115. elif size.upper().endswith("KB"):
  116. int_size = int(float(size[:-2]) * (10**3))
  117. mem_size = int_size // 8 if size.endswith("b") else int_size
  118. except ValueError:
  119. raise ValueError(err_msg)
  120. if mem_size < 0:
  121. raise ValueError(err_msg)
  122. return mem_size
  123. def dtype_byte_size(dtype: torch.dtype):
  124. """
  125. Returns the size (in bytes) occupied by one parameter of type `dtype`.
  126. Example:
  127. ```py
  128. >>> dtype_byte_size(torch.float32)
  129. 4
  130. ```
  131. """
  132. if dtype == torch.bool:
  133. return 1 / 8
  134. elif dtype == CustomDtype.INT2:
  135. return 1 / 4
  136. elif dtype == CustomDtype.INT4:
  137. return 1 / 2
  138. elif dtype == CustomDtype.FP8:
  139. return 1
  140. elif is_torch_version(">=", "2.1.0") and dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
  141. return 1
  142. bit_search = re.search(r"[^\d](\d+)$", str(dtype))
  143. if bit_search is None:
  144. raise ValueError(f"`dtype` is not a valid dtype: {dtype}.")
  145. bit_size = int(bit_search.groups()[0])
  146. return bit_size // 8
  147. def id_tensor_storage(tensor: torch.Tensor) -> tuple[torch.device, int, int]:
  148. """
  149. Unique identifier to a tensor storage. Multiple different tensors can share the same underlying storage. For
  150. example, "meta" tensors all share the same storage, and thus their identifier will all be equal. This identifier is
  151. guaranteed to be unique and constant for this tensor's storage during its lifetime. Two tensor storages with
  152. non-overlapping lifetimes may have the same id.
  153. """
  154. _SIZE = {
  155. torch.int64: 8,
  156. torch.float32: 4,
  157. torch.int32: 4,
  158. torch.bfloat16: 2,
  159. torch.float16: 2,
  160. torch.int16: 2,
  161. torch.uint8: 1,
  162. torch.int8: 1,
  163. torch.bool: 1,
  164. torch.float64: 8,
  165. }
  166. try:
  167. storage_ptr = tensor.untyped_storage().data_ptr()
  168. storage_size = tensor.untyped_storage().nbytes()
  169. except Exception:
  170. try:
  171. # Fallback for torch==1.10
  172. storage_ptr = tensor.storage().data_ptr()
  173. storage_size = tensor.storage().size() * _SIZE[tensor.dtype]
  174. except NotImplementedError:
  175. # Fallback for meta storage
  176. storage_ptr = 0
  177. # On torch >=2.0 this is the tensor size
  178. storage_size = tensor.nelement() * _SIZE[tensor.dtype]
  179. return tensor.device, storage_ptr, storage_size
def set_module_tensor_to_device(
    module: nn.Module,
    tensor_name: str,
    device: Union[int, str, torch.device],
    value: Optional[torch.Tensor] = None,
    dtype: Optional[Union[str, torch.dtype]] = None,
    fp16_statistics: Optional[torch.HalfTensor] = None,
    tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
    non_blocking: bool = False,
    clear_cache: bool = True,
):
    """
    A helper function to set a given tensor (parameter or buffer) of a module on a specific device.

    This is needed because `param.to(device)` creates a new tensor not linked to the parameter.

    Args:
        module (`torch.nn.Module`):
            The module in which the tensor we want to move lives.
        tensor_name (`str`):
            The full name of the parameter/buffer. Dotted names (e.g. `"layer.0.weight"`) are resolved attribute by
            attribute starting from `module`.
        device (`int`, `str` or `torch.device`):
            The device on which to set the tensor. When the matching backend is available, an `int` is rewritten to
            `"npu:<i>"`, `"mlu:<i>"`, `"sdaa:<i>"`, `"musa:<i>"` or `"hpu"`.
        value (`torch.Tensor`, *optional*):
            The value of the tensor (useful when going from the meta device to any other device). Required when the
            current tensor is on the meta device and `device` is not the meta device.
        dtype (`torch.dtype`, *optional*):
            If passed, the value of the parameter is cast to this `dtype`; integer and boolean values are left
            untouched. Otherwise, `value` is cast to the dtype of the existing parameter in the model.
        fp16_statistics (`torch.HalfTensor`, *optional*):
            The list of fp16 statistics to set on the module, used for 8 bit model serialization. Stored as the `SCB`
            attribute of the new parameter.
        tied_params_map (Dict[int, Dict[torch.device, torch.Tensor]], *optional*, defaults to `None`):
            A map of current data pointers to dictionaries of devices to already dispatched tied weights. For a given
            execution device, this parameter is useful to reuse the first available pointer of a shared weight on the
            device for all others, instead of duplicating memory.
            - If the tensor's (or `value`'s) data pointer already has an entry for `device`, that tensor is reused and
              the function returns early.
            - Otherwise, the newly created tensor is recorded for `device` at the end, provided the data pointer is
              already a key of the map.
        non_blocking (`bool`, *optional*, defaults to `False`):
            If `True`, the device transfer will be asynchronous with respect to the host, if possible.
        clear_cache (`bool`, *optional*, defaults to `True`):
            Whether or not to clear the device cache after setting the tensor on the device.

    Raises:
        `ValueError`: If any of the following holds:
            - an intermediate attribute named in `tensor_name` is `None`;
            - the final attribute is neither a parameter nor a buffer;
            - the tensor is on the meta device and no `value` is given for a non-meta `device`;
            - `value` has a different shape than the current tensor (except for `Params4bit`);
            - an XPU device is requested while XPU is not available.
    """
    # Recurse if needed: walk the dotted name down to the submodule that owns the tensor.
    if "." in tensor_name:
        splits = tensor_name.split(".")
        for split in splits[:-1]:
            new_module = getattr(module, split)
            if new_module is None:
                raise ValueError(f"{module} has no attribute {split}.")
            module = new_module
        tensor_name = splits[-1]

    if tensor_name not in module._parameters and tensor_name not in module._buffers:
        raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
    is_buffer = tensor_name in module._buffers
    old_value = getattr(module, tensor_name)

    # Treat the case where old_value (or a custom `value`, typically offloaded to RAM/disk) belongs to a tied group, and one of the weight
    # in the tied group has already been dispatched to the device, by avoiding reallocating memory on the device and just copying the pointer.
    # NOTE(review): both shortcuts write into `module._parameters` even when `tensor_name` is a buffer -- confirm that
    # tied buffers never reach this path.
    if (
        value is not None
        and tied_params_map is not None
        and value.data_ptr() in tied_params_map
        and device in tied_params_map[value.data_ptr()]
    ):
        module._parameters[tensor_name] = tied_params_map[value.data_ptr()][device]
        return
    elif (
        tied_params_map is not None
        and old_value.data_ptr() in tied_params_map
        and device in tied_params_map[old_value.data_ptr()]
    ):
        module._parameters[tensor_name] = tied_params_map[old_value.data_ptr()][device]
        return

    if old_value.device == torch.device("meta") and device not in ["meta", torch.device("meta")] and value is None:
        raise ValueError(f"{tensor_name} is on the meta device, we need a `value` to put in on {device}.")

    # `param` is `None` for buffers, in which case `param_cls` is `NoneType`.
    param = module._parameters[tensor_name] if tensor_name in module._parameters else None
    param_cls = type(param)

    if value is not None:
        # We can expect mismatches when using bnb 4bit since Params4bit will reshape and pack the weights.
        # In other cases, we want to make sure we're not loading checkpoints that do not match the config.
        if old_value.shape != value.shape and param_cls.__name__ != "Params4bit":
            raise ValueError(
                f'Trying to set a tensor of shape {value.shape} in "{tensor_name}" (which has shape {old_value.shape}), this looks incorrect.'
            )

        if dtype is None:
            # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model
            value = value.to(old_value.dtype, non_blocking=non_blocking)
        elif not str(value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
            # An explicit `dtype` is only applied to non-integer, non-boolean values.
            value = value.to(dtype, non_blocking=non_blocking)

    # Requested device of bnb quantized parameters that are temporarily staged on CPU (see below).
    device_quantization = None
    with torch.no_grad():
        # bnb quantized parameters (matched by class name) that are not yet on CUDA/XPU but target CUDA/XPU are first
        # built on CPU. The requested device is kept in `device_quantization` and restored once the new tensor exists.
        if (
            param is not None
            and param.device.type not in ("cuda", "xpu")
            and torch.device(device).type in ("cuda", "xpu")
            and param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]
        ):
            device_quantization = device
            device = "cpu"
        # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
        if isinstance(device, int):
            if is_npu_available():
                device = f"npu:{device}"
            elif is_mlu_available():
                device = f"mlu:{device}"
            elif is_sdaa_available():
                device = f"sdaa:{device}"
            elif is_musa_available():
                device = f"musa:{device}"
            elif is_hpu_available():
                device = "hpu"
        if "xpu" in str(device) and not is_xpu_available():
            raise ValueError(f'{device} is not available, you should use device="cpu" instead')
        if value is None:
            # No new value: move the existing tensor.
            new_value = old_value.to(device, non_blocking=non_blocking)
            if dtype is not None and device in ["meta", torch.device("meta")]:
                if not str(old_value.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
                    new_value = new_value.to(dtype, non_blocking=non_blocking)

                # The parameter is replaced right away. Afterwards it already sits on the target (meta) device, so the
                # device check below does not rebuild it a second time.
                if not is_buffer:
                    module._parameters[tensor_name] = param_cls(new_value, requires_grad=old_value.requires_grad)
        elif isinstance(value, torch.Tensor):
            new_value = value.to(device, non_blocking=non_blocking)
        else:
            # NOTE(review): `value.shape` / `value.to` are already used above, so non-tensor values reaching this
            # branch look unlikely -- confirm whether it is still needed.
            new_value = torch.tensor(value, device=device)
        if device_quantization is not None:
            device = device_quantization
        if is_buffer:
            module._buffers[tensor_name] = new_value
        elif value is not None or not check_device_same(torch.device(device), module._parameters[tensor_name].device):
            # Parameters are rebuilt only when a new value was given or the current one is not already on `device`.
            param_cls = type(module._parameters[tensor_name])
            # Attributes stored on the existing parameter object; forwarded only to the bnb constructors below.
            kwargs = module._parameters[tensor_name].__dict__
            if param_cls.__name__ in ["Int8Params", "FP4Params", "Params4bit"]:
                if param_cls.__name__ == "Int8Params" and new_value.dtype == torch.float32:
                    # downcast to fp16 if any - needed for 8bit serialization
                    new_value = new_value.to(torch.float16, non_blocking=non_blocking)
                # quantize module that are going to stay on the cpu so that we offload quantized weights
                # (the round trip through device 0 presumably triggers the quantization -- TODO confirm)
                if device == "cpu" and param_cls.__name__ == "Int8Params":
                    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(0).to("cpu")
                    new_value.CB = new_value.CB.to("cpu")
                    new_value.SCB = new_value.SCB.to("cpu")
                else:
                    new_value = param_cls(new_value, requires_grad=old_value.requires_grad, **kwargs).to(
                        device, non_blocking=non_blocking
                    )
            elif param_cls.__name__ in ["QTensor", "QBitsTensor"]:
                # Wrapped in a plain `torch.nn.Parameter` instead of re-instantiating `param_cls`.
                new_value = torch.nn.Parameter(new_value, requires_grad=old_value.requires_grad).to(
                    device, non_blocking=non_blocking
                )
            elif param_cls.__name__ in ["AffineQuantizedTensor"]:
                # Moved as-is, without re-wrapping.
                new_value = new_value.to(device, non_blocking=non_blocking)
            else:
                new_value = param_cls(new_value, requires_grad=old_value.requires_grad).to(
                    device, non_blocking=non_blocking
                )

            module._parameters[tensor_name] = new_value
            if fp16_statistics is not None:
                module._parameters[tensor_name].SCB = fp16_statistics.to(device, non_blocking=non_blocking)
                del fp16_statistics
            # as we put the weight to meta, it doesn't have SCB attr anymore. make sure that it is not a meta weight
            if (
                module.__class__.__name__ == "Linear8bitLt"
                and getattr(module.weight, "SCB", None) is None
                and str(module.weight.device) != "meta"
            ):
                # quantize only if necessary
                device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
                if not getattr(module.weight, "SCB", None) and device_index is not None:
                    if module.bias is not None and module.bias.device.type != "meta":
                        # if a bias exists, we need to wait until the bias is set on the correct device
                        module = module.cuda(device_index)
                    elif module.bias is None:
                        # if no bias exists, we can quantize right away
                        module = module.cuda(device_index)
            elif (
                module.__class__.__name__ == "Linear4bit"
                and getattr(module.weight, "quant_state", None) is None
                and str(module.weight.device) != "meta"
            ):
                # quantize only if necessary
                device_index = torch.device(device).index if torch.device(device).type == "cuda" else None
                if not getattr(module.weight, "quant_state", None) and device_index is not None:
                    module.weight = module.weight.cuda(device_index)

    # Optionally call `clear_device_cache()` after the transfer; skipped when `device` is "cpu" or "meta".
    if clear_cache and device not in ("cpu", "meta"):
        clear_device_cache()

    # When handling tied weights, we update tied_params_map to keep track of the tied weights that have already been allocated on the device in
    # order to avoid duplicating memory, see above.
    if (
        tied_params_map is not None
        and old_value.data_ptr() in tied_params_map
        and device not in tied_params_map[old_value.data_ptr()]
    ):
        tied_params_map[old_value.data_ptr()][device] = new_value
    elif (
        value is not None
        and tied_params_map is not None
        and value.data_ptr() in tied_params_map
        and device not in tied_params_map[value.data_ptr()]
    ):
        tied_params_map[value.data_ptr()][device] = new_value
  376. def named_module_tensors(
  377. module: nn.Module, include_buffers: bool = True, recurse: bool = False, remove_non_persistent: bool = False
  378. ):
  379. """
  380. A helper function that gathers all the tensors (parameters + buffers) of a given module. If `include_buffers=True`
  381. it's the same as doing `module.named_parameters(recurse=recurse) + module.named_buffers(recurse=recurse)`.
  382. Args:
  383. module (`torch.nn.Module`):
  384. The module we want the tensors on.
  385. include_buffer (`bool`, *optional*, defaults to `True`):
  386. Whether or not to include the buffers in the result.
  387. recurse (`bool`, *optional`, defaults to `False`):
  388. Whether or not to go look in every submodule or just return the direct parameters and buffers.
  389. remove_non_persistent (`bool`, *optional*, defaults to `False`):
  390. Whether or not to remove the non persistent buffer from the buffers. Useful only when include_buffers =
  391. True
  392. """
  393. yield from module.named_parameters(recurse=recurse)
  394. if include_buffers:
  395. non_persistent_buffers = set()
  396. if remove_non_persistent:
  397. non_persistent_buffers = get_non_persistent_buffers(module, recurse=recurse)
  398. for named_buffer in module.named_buffers(recurse=recurse):
  399. name, _ = named_buffer
  400. if name not in non_persistent_buffers:
  401. yield named_buffer
  402. def get_non_persistent_buffers(module: nn.Module, recurse: bool = False, fqns: bool = False):
  403. """
  404. Gather all non persistent buffers of a given modules into a set
  405. Args:
  406. module (`nn.Module`):
  407. The module we want the non persistent buffers on.
  408. recurse (`bool`, *optional*, defaults to `False`):
  409. Whether or not to go look in every submodule or just return the direct non persistent buffers.
  410. fqns (`bool`, *optional*, defaults to `False`):
  411. Whether or not to return the fully-qualified names of the non persistent buffers.
  412. """
  413. non_persistent_buffers_set = module._non_persistent_buffers_set
  414. if recurse:
  415. for n, m in module.named_modules():
  416. if fqns:
  417. non_persistent_buffers_set |= {n + "." + b for b in m._non_persistent_buffers_set}
  418. else:
  419. non_persistent_buffers_set |= m._non_persistent_buffers_set
  420. return non_persistent_buffers_set
  421. def check_tied_parameters_in_config(model: nn.Module):
  422. """
  423. Check if there is any indication in the given model that some weights should be tied.
  424. Args:
  425. model (`torch.nn.Module`): The model to inspect
  426. Returns:
  427. bool: True if the model needs to have tied weights
  428. """
  429. # based on model.tie_weights() method
  430. has_tied_word_embedding = False
  431. has_tied_encoder_decoder = False
  432. has_tied_module = False
  433. if "PreTrainedModel" in [c.__name__ for c in inspect.getmro(model.__class__)]:
  434. has_tied_word_embedding = False
  435. model_decoder_config = None
  436. if hasattr(model, "config"):
  437. model_decoder_config = (
  438. model.config.get_text_config(decoder=True)
  439. if hasattr(model.config, "get_text_config")
  440. else model.config
  441. )
  442. has_tied_word_embedding = (
  443. model_decoder_config is not None
  444. and getattr(model_decoder_config, "tie_word_embeddings", False)
  445. and model.get_output_embeddings()
  446. )
  447. has_tied_encoder_decoder = (
  448. hasattr(model, "config")
  449. and getattr(model.config, "is_encoder_decoder", False)
  450. and getattr(model.config, "tie_encoder_decoder", False)
  451. )
  452. has_tied_module = any(hasattr(module, "_tie_weights") for module in model.modules())
  453. return any([has_tied_word_embedding, has_tied_encoder_decoder, has_tied_module])
  454. def _get_param_device(param, device_map):
  455. if param in device_map:
  456. return device_map[param]
  457. parent_param = ".".join(param.split(".")[:-1])
  458. if parent_param == param:
  459. raise ValueError(f"The `device_map` does not contain the module {param}.")
  460. else:
  461. return _get_param_device(parent_param, device_map)
  462. def check_tied_parameters_on_same_device(tied_params, device_map):
  463. """
  464. Check if tied parameters are on the same device
  465. Args:
  466. tied_params (`List[List[str]]`):
  467. A list of lists of parameter names being all tied together.
  468. device_map (`Dict[str, Union[int, str, torch.device]]`):
  469. A map that specifies where each submodule should go.
  470. """
  471. for tie_param in tied_params:
  472. tie_param_devices = {}
  473. for param in tie_param:
  474. tie_param_devices[param] = _get_param_device(param, device_map)
  475. if len(set(tie_param_devices.values())) > 1:
  476. logger.warning(
  477. f"Tied parameters are on different devices: {tie_param_devices}. "
  478. "Please modify your custom device map or set `device_map='auto'`. "
  479. )
  480. def find_tied_parameters(model: torch.nn.Module, **kwargs) -> list[list[str]]:
  481. """
  482. Find the tied parameters in a given model.
  483. <Tip warning={true}>
  484. The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
  485. them.
  486. </Tip>
  487. Args:
  488. model (`torch.nn.Module`): The model to inspect.
  489. Returns:
  490. List[List[str]]: A list of lists of parameter names being all tied together.
  491. Example:
  492. ```py
  493. >>> from collections import OrderedDict
  494. >>> import torch.nn as nn
  495. >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
  496. >>> model.linear2.weight = model.linear1.weight
  497. >>> find_tied_parameters(model)
  498. [['linear1.weight', 'linear2.weight']]
  499. ```
  500. """
  501. # get ALL model parameters and their names
  502. all_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=False)}
  503. # get ONLY unique named parameters,
  504. # if parameter is tied and have multiple names, it will be included only once
  505. no_duplicate_named_parameters = {name: param for name, param in model.named_parameters(remove_duplicate=True)}
  506. # the difference of the two sets will give us the tied parameters
  507. tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
  508. # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
  509. # which names refer to the same parameter. To identify this, we need to group them together.
  510. tied_param_groups = {}
  511. for tied_param_name in tied_param_names:
  512. tied_param = all_named_parameters[tied_param_name]
  513. for param_name, param in no_duplicate_named_parameters.items():
  514. # compare if parameters are the same, if so, group their names together
  515. if param is tied_param:
  516. if param_name not in tied_param_groups:
  517. tied_param_groups[param_name] = []
  518. tied_param_groups[param_name].append(tied_param_name)
  519. return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]
  520. def retie_parameters(model, tied_params):
  521. """
  522. Reties tied parameters in a given model if the link was broken (for instance when adding hooks).
  523. Args:
  524. model (`torch.nn.Module`):
  525. The model in which to retie parameters.
  526. tied_params (`List[List[str]]`):
  527. A mapping parameter name to tied parameter name as obtained by `find_tied_parameters`.
  528. """
  529. for tied_group in tied_params:
  530. param_to_tie = None
  531. # two loops : the first one to set param_to_tie , the second one to change the values of tied_group
  532. for param_name in tied_group:
  533. module = model
  534. splits = param_name.split(".")
  535. for split in splits[:-1]:
  536. module = getattr(module, split)
  537. param = getattr(module, splits[-1])
  538. if param_to_tie is None and param.device != torch.device("meta"):
  539. param_to_tie = param
  540. break
  541. if param_to_tie is not None:
  542. for param_name in tied_group:
  543. module = model
  544. splits = param_name.split(".")
  545. for split in splits[:-1]:
  546. module = getattr(module, split)
  547. setattr(module, splits[-1], param_to_tie)
  548. def _get_proper_dtype(dtype: Union[str, torch.device]) -> torch.dtype:
  549. """
  550. Just does torch.dtype(dtype) if necessary.
  551. """
  552. if isinstance(dtype, str):
  553. # We accept "torch.float16" or just "float16"
  554. dtype = dtype.replace("torch.", "")
  555. dtype = getattr(torch, dtype)
  556. return dtype
  557. def compute_module_sizes(
  558. model: nn.Module,
  559. dtype: Optional[Union[str, torch.device]] = None,
  560. special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
  561. buffers_only: bool = False,
  562. ):
  563. """
  564. Compute the size of each submodule of a given model.
  565. """
  566. if dtype is not None:
  567. dtype = _get_proper_dtype(dtype)
  568. dtype_size = dtype_byte_size(dtype)
  569. if special_dtypes is not None:
  570. special_dtypes = {key: _get_proper_dtype(dtyp) for key, dtyp in special_dtypes.items()}
  571. special_dtypes_size = {key: dtype_byte_size(dtyp) for key, dtyp in special_dtypes.items()}
  572. module_sizes = defaultdict(int)
  573. module_list = []
  574. if not buffers_only:
  575. module_list = named_module_tensors(model, recurse=True)
  576. else:
  577. module_list = model.named_buffers(recurse=True)
  578. for name, tensor in module_list:
  579. if special_dtypes is not None and name in special_dtypes:
  580. size = tensor.numel() * special_dtypes_size[name]
  581. elif dtype is None:
  582. size = tensor.numel() * dtype_byte_size(tensor.dtype)
  583. elif str(tensor.dtype).startswith(("torch.uint", "torch.int", "torch.bool")):
  584. # According to the code in set_module_tensor_to_device, these types won't be converted
  585. # so use their original size here
  586. size = tensor.numel() * dtype_byte_size(tensor.dtype)
  587. else:
  588. size = tensor.numel() * min(dtype_size, dtype_byte_size(tensor.dtype))
  589. name_parts = name.split(".")
  590. for idx in range(len(name_parts) + 1):
  591. module_sizes[".".join(name_parts[:idx])] += size
  592. return module_sizes
  593. def compute_module_total_buffer_size(
  594. model: nn.Module,
  595. dtype: Optional[Union[str, torch.device]] = None,
  596. special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
  597. ):
  598. """
  599. Compute the total size of buffers in each submodule of a given model.
  600. """
  601. module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes, buffers_only=True)
  602. return module_sizes.get("", 0)
  603. def get_max_layer_size(
  604. modules: list[tuple[str, torch.nn.Module]], module_sizes: dict[str, int], no_split_module_classes: list[str]
  605. ):
  606. """
  607. Utility function that will scan a list of named modules and return the maximum size used by one full layer. The
  608. definition of a layer being:
  609. - a module with no direct children (just parameters and buffers)
  610. - a module whose class name is in the list `no_split_module_classes`
  611. Args:
  612. modules (`List[Tuple[str, torch.nn.Module]]`):
  613. The list of named modules where we want to determine the maximum layer size.
  614. module_sizes (`Dict[str, int]`):
  615. A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
  616. no_split_module_classes (`List[str]`):
  617. A list of class names for layers we don't want to be split.
  618. Returns:
  619. `Tuple[int, List[str]]`: The maximum size of a layer with the list of layer names realizing that maximum size.
  620. """
  621. max_size = 0
  622. layer_names = []
  623. modules_to_treat = modules.copy()
  624. while len(modules_to_treat) > 0:
  625. module_name, module = modules_to_treat.pop(0)
  626. modules_children = list(module.named_children()) if isinstance(module, torch.nn.Module) else []
  627. if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
  628. # No splitting this one so we compare to the max_size
  629. size = module_sizes[module_name]
  630. if size > max_size:
  631. max_size = size
  632. layer_names = [module_name]
  633. elif size == max_size:
  634. layer_names.append(module_name)
  635. else:
  636. modules_to_treat = [(f"{module_name}.{n}", v) for n, v in modules_children] + modules_to_treat
  637. return max_size, layer_names
  638. def get_max_memory(max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None):
  639. """
  640. Get the maximum memory available if nothing is passed, converts string to int otherwise.
  641. """
  642. import psutil
  643. if max_memory is None:
  644. max_memory = {}
  645. # Make sure CUDA is initialized on each GPU to have the right memory info.
  646. if is_npu_available():
  647. for i in range(torch.npu.device_count()):
  648. try:
  649. _ = torch.tensor(0, device=torch.device("npu", i))
  650. max_memory[i] = torch.npu.mem_get_info(i)[0]
  651. except Exception:
  652. logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
  653. continue
  654. elif is_mlu_available():
  655. for i in range(torch.mlu.device_count()):
  656. try:
  657. _ = torch.tensor(0, device=torch.device("mlu", i))
  658. max_memory[i] = torch.mlu.mem_get_info(i)[0]
  659. except Exception:
  660. logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
  661. continue
  662. elif is_sdaa_available():
  663. for i in range(torch.sdaa.device_count()):
  664. try:
  665. _ = torch.tensor(0, device=torch.device("sdaa", i))
  666. max_memory[i] = torch.sdaa.mem_get_info(i)[0]
  667. except Exception:
  668. logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
  669. continue
  670. elif is_musa_available():
  671. for i in range(torch.musa.device_count()):
  672. try:
  673. _ = torch.tensor(0, device=torch.device("musa", i))
  674. max_memory[i] = torch.musa.mem_get_info(i)[0]
  675. except Exception:
  676. logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
  677. continue
  678. elif is_xpu_available():
  679. for i in range(torch.xpu.device_count()):
  680. try:
  681. _ = torch.tensor(0, device=torch.device("xpu", i))
  682. max_memory[i] = get_xpu_available_memory(i)
  683. except Exception:
  684. logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
  685. continue
  686. elif is_hpu_available():
  687. for i in range(torch.hpu.device_count()):
  688. try:
  689. _ = torch.tensor(0, device=torch.device("hpu", i))
  690. max_memory[i] = torch.hpu.mem_get_info(i)[0]
  691. except Exception:
  692. logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
  693. continue
  694. else:
  695. for i in range(torch.cuda.device_count()):
  696. try:
  697. _ = torch.tensor([0], device=i)
  698. max_memory[i] = torch.cuda.mem_get_info(i)[0]
  699. except Exception:
  700. logger.info(f"Device {i} seems unavailable, Proceeding to check subsequent devices.")
  701. continue
  702. # allocate everything in the mps device as the RAM is shared
  703. if is_mps_available():
  704. max_memory["mps"] = psutil.virtual_memory().available
  705. else:
  706. max_memory["cpu"] = psutil.virtual_memory().available
  707. return max_memory
  708. for key in max_memory:
  709. if isinstance(max_memory[key], str):
  710. max_memory[key] = convert_file_size_to_int(max_memory[key])
  711. # Need to sort the device by type to make sure that we allocate the gpu first.
  712. # As gpu/npu/xpu are represented by int, we need to sort them first.
  713. gpu_devices = [k for k in max_memory.keys() if isinstance(k, int)]
  714. gpu_devices.sort()
  715. # check if gpu/npu/xpu devices are available and if not, throw a warning
  716. if is_npu_available():
  717. num_devices = torch.npu.device_count()
  718. elif is_mlu_available():
  719. num_devices = torch.mlu.device_count()
  720. elif is_sdaa_available():
  721. num_devices = torch.sdaa.device_count()
  722. elif is_musa_available():
  723. num_devices = torch.musa.device_count()
  724. elif is_xpu_available():
  725. num_devices = torch.xpu.device_count()
  726. elif is_hpu_available():
  727. num_devices = torch.hpu.device_count()
  728. else:
  729. num_devices = torch.cuda.device_count()
  730. for device in gpu_devices:
  731. if device >= num_devices or device < 0:
  732. logger.warning(f"Device {device} is not available, available devices are {list(range(num_devices))}")
  733. # Add the other devices in the preset order if they are available
  734. all_devices = gpu_devices + [k for k in ["mps", "cpu", "disk"] if k in max_memory.keys()]
  735. # Raise an error if a device is not recognized
  736. for k in max_memory.keys():
  737. if k not in all_devices:
  738. raise ValueError(
  739. f"Device {k} is not recognized, available devices are integers(for GPU/XPU), 'mps', 'cpu' and 'disk'"
  740. )
  741. max_memory = {k: max_memory[k] for k in all_devices}
  742. return max_memory
  743. def clean_device_map(device_map: dict[str, Union[int, str, torch.device]], module_name: str = ""):
  744. """
  745. Cleans a device_map by grouping all submodules that go on the same device together.
  746. """
  747. # Get the value of the current module and if there is only one split across several keys, regroup it.
  748. prefix = "" if module_name == "" else f"{module_name}."
  749. values = [v for k, v in device_map.items() if k.startswith(prefix)]
  750. if len(set(values)) == 1 and len(values) > 1:
  751. for k in [k for k in device_map if k.startswith(prefix)]:
  752. del device_map[k]
  753. device_map[module_name] = values[0]
  754. # Recurse over the children
  755. children_modules = [k for k in device_map.keys() if k.startswith(prefix) and len(k) > len(module_name)]
  756. idx = len(module_name.split(".")) + 1 if len(module_name) > 0 else 1
  757. children_modules = set(".".join(k.split(".")[:idx]) for k in children_modules)
  758. for child in children_modules:
  759. clean_device_map(device_map, module_name=child)
  760. return device_map
  761. def load_offloaded_weights(model, index, offload_folder):
  762. """
  763. Loads the weights from the offload folder into the model.
  764. Args:
  765. model (`torch.nn.Module`):
  766. The model to load the weights into.
  767. index (`dict`):
  768. A dictionary containing the parameter name and its metadata for each parameter that was offloaded from the
  769. model.
  770. offload_folder (`str`):
  771. The folder where the offloaded weights are stored.
  772. """
  773. if index is None or len(index) == 0:
  774. # Nothing to do
  775. return
  776. for param_name, metadata in index.items():
  777. if "SCB" in param_name:
  778. continue
  779. fp16_statistics = None
  780. if "weight" in param_name and param_name.replace("weight", "SCB") in index.keys():
  781. weight_name = param_name.replace("weight", "SCB")
  782. fp16_statistics = load_offloaded_weight(
  783. os.path.join(offload_folder, f"{weight_name}.dat"), index[weight_name]
  784. )
  785. tensor_file = os.path.join(offload_folder, f"{param_name}.dat")
  786. weight = load_offloaded_weight(tensor_file, metadata)
  787. set_module_tensor_to_device(model, param_name, "cpu", value=weight, fp16_statistics=fp16_statistics)
  788. def get_module_leaves(module_sizes):
  789. module_children = {}
  790. for module in module_sizes:
  791. if module == "" or "." not in module:
  792. continue
  793. parent = module.rsplit(".", 1)[0]
  794. module_children[parent] = module_children.get(parent, 0) + 1
  795. leaves = [module for module in module_sizes if module_children.get(module, 0) == 0 and module != ""]
  796. return leaves
  797. def get_balanced_memory(
  798. model: nn.Module,
  799. max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
  800. no_split_module_classes: Optional[list[str]] = None,
  801. dtype: Optional[Union[str, torch.dtype]] = None,
  802. special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
  803. low_zero: bool = False,
  804. ):
  805. """
  806. Compute a `max_memory` dictionary for [`infer_auto_device_map`] that will balance the use of each available GPU.
  807. <Tip>
  808. All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
  809. meta device (as it would if initialized within the `init_empty_weights` context manager).
  810. </Tip>
  811. Args:
  812. model (`torch.nn.Module`):
  813. The model to analyze.
  814. max_memory (`Dict`, *optional*):
  815. A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
  816. Example: `max_memory={0: "1GB"}`.
  817. no_split_module_classes (`List[str]`, *optional*):
  818. A list of layer class names that should never be split across device (for instance any layer that has a
  819. residual connection).
  820. dtype (`str` or `torch.dtype`, *optional*):
  821. If provided, the weights will be converted to that type when loaded.
  822. special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
  823. If provided, special dtypes to consider for some specific weights (will override dtype used as default for
  824. all weights).
  825. low_zero (`bool`, *optional*):
  826. Minimizes the number of weights on GPU 0, which is convenient when it's used for other operations (like the
  827. Transformers generate function).
  828. """
  829. # Get default / clean up max_memory
  830. user_not_set_max_memory = max_memory is None
  831. max_memory = get_max_memory(max_memory)
  832. if is_npu_available():
  833. expected_device_type = "npu"
  834. elif is_mlu_available():
  835. expected_device_type = "mlu"
  836. elif is_sdaa_available():
  837. expected_device_type = "sdaa"
  838. elif is_musa_available():
  839. expected_device_type = "musa"
  840. elif is_xpu_available():
  841. expected_device_type = "xpu"
  842. elif is_hpu_available():
  843. expected_device_type = "hpu"
  844. elif is_mps_available():
  845. expected_device_type = "mps"
  846. else:
  847. expected_device_type = "cuda"
  848. num_devices = len([d for d in max_memory if torch.device(d).type == expected_device_type and max_memory[d] > 0])
  849. if num_devices == 0:
  850. return max_memory
  851. if num_devices == 1:
  852. # We cannot do low_zero on just one GPU, but we will still reserve some memory for the buffer
  853. low_zero = False
  854. # If user just asked us to handle memory usage, we should avoid OOM
  855. if user_not_set_max_memory:
  856. for key in max_memory.keys():
  857. if isinstance(key, int):
  858. max_memory[key] *= 0.9 # 90% is a good compromise
  859. logger.info(
  860. f"We will use 90% of the memory on device {key} for storing the model, and 10% for the buffer to avoid OOM. "
  861. "You can set `max_memory` in to a higher value to use more memory (at your own risk)."
  862. )
  863. break # only one device
  864. module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
  865. per_gpu = module_sizes[""] // (num_devices - 1 if low_zero else num_devices)
  866. # We can't just set the memory to model_size // num_devices as it will end being too small: each GPU will get
  867. # slightly less layers and some layers will end up offload at the end. So this function computes a buffer size to
  868. # add which is the biggest of:
  869. # - the size of no split block (if applicable)
  870. # - the mean of the layer sizes
  871. if no_split_module_classes is None:
  872. no_split_module_classes = []
  873. elif not isinstance(no_split_module_classes, (list, tuple)):
  874. no_split_module_classes = [no_split_module_classes]
  875. # Identify the size of the no_split_block modules
  876. if len(no_split_module_classes) > 0:
  877. no_split_children = {}
  878. for name, size in module_sizes.items():
  879. if name == "":
  880. continue
  881. submodule = model
  882. for submodule_name in name.split("."):
  883. submodule = getattr(submodule, submodule_name)
  884. class_name = submodule.__class__.__name__
  885. if class_name in no_split_module_classes and class_name not in no_split_children:
  886. no_split_children[class_name] = size
  887. if set(no_split_children.keys()) == set(no_split_module_classes):
  888. break
  889. buffer = max(no_split_children.values()) if len(no_split_children) > 0 else 0
  890. else:
  891. buffer = 0
  892. # Compute mean of final modules. In the first dict of module sizes, leaves are the parameters
  893. leaves = get_module_leaves(module_sizes)
  894. leaves_set = set(leaves) # Convert to set for O(1) membership testing
  895. module_sizes = {n: v for n, v in module_sizes.items() if n not in leaves_set}
  896. # Once removed, leaves are the final modules.
  897. leaves = get_module_leaves(module_sizes)
  898. mean_leaves = int(sum([module_sizes[n] for n in leaves]) / max(len(leaves), 1))
  899. buffer = int(1.25 * max(buffer, mean_leaves))
  900. per_gpu += buffer
  901. # Sorted list of GPUs id (we may have some gpu ids not included in the our max_memory list - let's ignore them)
  902. gpus_idx_list = list(
  903. sorted(
  904. device_id for device_id, device_mem in max_memory.items() if isinstance(device_id, int) and device_mem > 0
  905. )
  906. )
  907. # The last device is left with max_memory just in case the buffer is not enough.
  908. for idx in gpus_idx_list[:-1]:
  909. max_memory[idx] = min(max_memory[0] if low_zero and idx == 0 else per_gpu, max_memory[idx])
  910. if low_zero:
  911. min_zero = max(0, module_sizes[""] - sum([max_memory[i] for i in range(1, num_devices)]))
  912. max_memory[0] = min(min_zero, max_memory[0])
  913. return max_memory
  914. def calculate_maximum_sizes(model: torch.nn.Module):
  915. "Computes the total size of the model and its largest layer"
  916. sizes = compute_module_sizes(model)
  917. # `transformers` models store this information for us
  918. no_split_modules = getattr(model, "_no_split_modules", None)
  919. if no_split_modules is None:
  920. no_split_modules = []
  921. modules_to_treat = (
  922. list(model.named_parameters(recurse=False))
  923. + list(model.named_children())
  924. + list(model.named_buffers(recurse=False))
  925. )
  926. largest_layer = get_max_layer_size(modules_to_treat, sizes, no_split_modules)
  927. total_size = sizes[""]
  928. return total_size, largest_layer
  929. def _init_infer_auto_device_map(
  930. model: nn.Module,
  931. max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
  932. no_split_module_classes: Optional[list[str]] = None,
  933. dtype: Optional[Union[str, torch.dtype]] = None,
  934. special_dtypes: Optional[dict[str, Union[str, torch.device]]] = None,
  935. ) -> tuple[
  936. list[Union[int, str]],
  937. dict[Union[int, str], Union[int, str]],
  938. list[Union[int, str]],
  939. list[int],
  940. dict[str, int],
  941. list[list[str]],
  942. list[str],
  943. list[tuple[str, nn.Module]],
  944. ]:
  945. """
  946. Initialize variables required for computing the device map for model allocation.
  947. """
  948. max_memory = get_max_memory(max_memory)
  949. if no_split_module_classes is None:
  950. no_split_module_classes = []
  951. elif not isinstance(no_split_module_classes, (list, tuple)):
  952. no_split_module_classes = [no_split_module_classes]
  953. devices = list(max_memory.keys())
  954. if "disk" not in devices:
  955. devices.append("disk")
  956. gpus = [device for device in devices if device not in ["cpu", "disk"]]
  957. # Devices that need to keep space for a potential offloaded layer.
  958. if "mps" in gpus:
  959. main_devices = ["mps"]
  960. elif len(gpus) > 0:
  961. main_devices = [gpus[0], "cpu"]
  962. else:
  963. main_devices = ["cpu"]
  964. module_sizes = compute_module_sizes(model, dtype=dtype, special_dtypes=special_dtypes)
  965. tied_parameters = find_tied_parameters(model)
  966. if check_tied_parameters_in_config(model) and len(tied_parameters) == 0:
  967. logger.warning(
  968. "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
  969. )
  970. # Direct submodules and parameters
  971. modules_to_treat = (
  972. list(model.named_parameters(recurse=False))
  973. + list(model.named_children())
  974. + list(model.named_buffers(recurse=False))
  975. )
  976. return (
  977. devices,
  978. max_memory,
  979. main_devices,
  980. gpus,
  981. module_sizes,
  982. tied_parameters,
  983. no_split_module_classes,
  984. modules_to_treat,
  985. )
  986. def get_module_size_with_ties(
  987. tied_params,
  988. module_size,
  989. module_sizes,
  990. modules_to_treat,
  991. ) -> tuple[int, list[str], list[nn.Module]]:
  992. """
  993. Calculate the total size of a module, including its tied parameters.
  994. Args:
  995. tied_params (`List[str]`): The list of tied parameters.
  996. module_size (`int`): The size of the module without tied parameters.
  997. module_sizes (`Dict[str, int]`): A dictionary mapping each layer name to its size.
  998. modules_to_treat (`List[Tuple[str, nn.Module]]`): The list of named modules to treat.
  999. Returns:
  1000. `Tuple[int, List[str], List[nn.Module]]`: The total size of the module, the names of the tied modules, and the
  1001. tied modules.
  1002. """
  1003. if len(tied_params) < 1:
  1004. return module_size, [], []
  1005. tied_module_names = []
  1006. tied_modules = []
  1007. for tied_param in tied_params:
  1008. tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if tied_param.startswith(n + ".")][0]
  1009. tied_module_names.append(modules_to_treat[tied_module_index][0])
  1010. tied_modules.append(modules_to_treat[tied_module_index][1])
  1011. module_size_with_ties = module_size
  1012. for tied_param, tied_module_name in zip(tied_params, tied_module_names):
  1013. module_size_with_ties += module_sizes[tied_module_name] - module_sizes[tied_param]
  1014. return module_size_with_ties, tied_module_names, tied_modules
  1015. def fallback_allocate(
  1016. modules: list[tuple[str, nn.Module]],
  1017. module_sizes: dict[str, int],
  1018. size_limit: Union[int, str],
  1019. no_split_module_classes: Optional[list[str]] = None,
  1020. tied_parameters: Optional[list[list[str]]] = None,
  1021. ) -> tuple[Optional[str], Optional[nn.Module], list[tuple[str, nn.Module]]]:
  1022. """
  1023. Find a module that fits in the size limit using BFS and return it with its name and the remaining modules.
  1024. Args:
  1025. modules (`List[Tuple[str, nn.Module]]`):
  1026. The list of named modules to search in.
  1027. module_sizes (`Dict[str, int]`):
  1028. A dictionary mapping each layer name to its size (as generated by `compute_module_sizes`).
  1029. size_limit (`Union[int, str]`):
  1030. The maximum size a module can have.
  1031. no_split_module_classes (`Optional[List[str]]`, *optional*):
  1032. A list of class names for layers we don't want to be split.
  1033. tied_parameters (`Optional[List[List[str]]`, *optional*):
  1034. A list of lists of parameter names being all tied together.
  1035. Returns:
  1036. `Tuple[Optional[str], Optional[nn.Module], List[Tuple[str, nn.Module]]]`: A tuple containing:
  1037. - The name of the module that fits within the size limit.
  1038. - The module itself.
  1039. - The list of remaining modules after the found module is removed.
  1040. """
  1041. try:
  1042. size_limit = convert_file_size_to_int(size_limit)
  1043. except ValueError:
  1044. return None, None, modules
  1045. if no_split_module_classes is None:
  1046. no_split_module_classes = []
  1047. if tied_parameters is None:
  1048. tied_parameters = []
  1049. modules_to_search = modules.copy()
  1050. module_found = False
  1051. while modules_to_search:
  1052. name, module = modules_to_search.pop(0)
  1053. tied_param_groups = [
  1054. tied_group
  1055. for tied_group in tied_parameters
  1056. if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
  1057. ]
  1058. tied_params = sum(
  1059. [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
  1060. )
  1061. module_size_with_ties, _, _ = get_module_size_with_ties(
  1062. tied_params, module_sizes[name], module_sizes, modules_to_search
  1063. )
  1064. # If the module fits in the size limit, we found it.
  1065. if module_size_with_ties <= size_limit:
  1066. module_found = True
  1067. break
  1068. # The module is too big, we need to split it if possible.
  1069. modules_children = (
  1070. []
  1071. if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
  1072. else list(module.named_children())
  1073. )
  1074. # Split fails, move to the next module
  1075. if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
  1076. continue
  1077. # split is possible, add the children to the list of modules to search
  1078. modules_children = list(module.named_parameters(recurse=False)) + modules_children
  1079. modules_to_search = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_search
  1080. if not module_found:
  1081. return None, None, modules
  1082. # Prepare the module list for removal of the found module
  1083. current_names = [n for n, _ in modules]
  1084. dot_idx = [i for i, c in enumerate(name) if c == "."]
  1085. for dot_index in dot_idx:
  1086. parent_name = name[:dot_index]
  1087. if parent_name in current_names:
  1088. parent_module_idx = current_names.index(parent_name)
  1089. _, parent_module = modules[parent_module_idx]
  1090. module_children = list(parent_module.named_parameters(recurse=False)) + list(
  1091. parent_module.named_children()
  1092. )
  1093. modules = (
  1094. modules[:parent_module_idx]
  1095. + [(f"{parent_name}.{n}", v) for n, v in module_children]
  1096. + modules[parent_module_idx + 1 :]
  1097. )
  1098. current_names = [n for n, _ in modules]
  1099. # Now the target module should be directly in the list
  1100. target_idx = current_names.index(name)
  1101. name, module = modules.pop(target_idx)
  1102. return name, module, modules
  1103. def infer_auto_device_map(
  1104. model: nn.Module,
  1105. max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
  1106. no_split_module_classes: Optional[list[str]] = None,
  1107. dtype: Optional[Union[str, torch.dtype]] = None,
  1108. special_dtypes: Optional[dict[str, Union[str, torch.dtype]]] = None,
  1109. verbose: bool = False,
  1110. clean_result: bool = True,
  1111. offload_buffers: bool = False,
  1112. fallback_allocation: bool = False,
  1113. ):
  1114. """
  1115. Compute a device map for a given model giving priority to GPUs, then offload on CPU and finally offload to disk,
  1116. such that:
  1117. - we don't exceed the memory available of any of the GPU.
  1118. - if offload to the CPU is needed, there is always room left on GPU 0 to put back the layer offloaded on CPU that
  1119. has the largest size.
  1120. - if offload to the CPU is needed,we don't exceed the RAM available on the CPU.
  1121. - if offload to the disk is needed, there is always room left on the CPU to put back the layer offloaded on disk
  1122. that has the largest size.
  1123. <Tip>
  1124. All computation is done analyzing sizes and dtypes of the model parameters. As a result, the model can be on the
  1125. meta device (as it would if initialized within the `init_empty_weights` context manager).
  1126. </Tip>
  1127. Args:
  1128. model (`torch.nn.Module`):
  1129. The model to analyze.
  1130. max_memory (`Dict`, *optional*):
  1131. A dictionary device identifier to maximum memory. Will default to the maximum memory available if unset.
  1132. Example: `max_memory={0: "1GB"}`.
  1133. no_split_module_classes (`List[str]`, *optional*):
  1134. A list of layer class names that should never be split across device (for instance any layer that has a
  1135. residual connection).
  1136. dtype (`str` or `torch.dtype`, *optional*):
  1137. If provided, the weights will be converted to that type when loaded.
  1138. special_dtypes (`Dict[str, Union[str, torch.device]]`, *optional*):
  1139. If provided, special dtypes to consider for some specific weights (will override dtype used as default for
  1140. all weights).
  1141. verbose (`bool`, *optional*, defaults to `False`):
  1142. Whether or not to provide debugging statements as the function builds the device_map.
  1143. clean_result (`bool`, *optional*, defaults to `True`):
  1144. Clean the resulting device_map by grouping all submodules that go on the same device together.
  1145. offload_buffers (`bool`, *optional*, defaults to `False`):
  1146. In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
  1147. well as the parameters.
  1148. fallback_allocation (`bool`, *optional*, defaults to `False`):
  1149. When regular allocation fails, try to allocate a module that fits in the size limit using BFS.
  1150. """
  1151. # Initialize the variables
  1152. (
  1153. devices,
  1154. max_memory,
  1155. main_devices,
  1156. gpus,
  1157. module_sizes,
  1158. tied_parameters,
  1159. no_split_module_classes,
  1160. modules_to_treat,
  1161. ) = _init_infer_auto_device_map(model, max_memory, no_split_module_classes, dtype, special_dtypes)
  1162. device_map = OrderedDict()
  1163. current_device = 0
  1164. device_memory_used = {device: 0 for device in devices}
  1165. device_buffer_sizes = {}
  1166. device_minimum_assignment_memory = {}
  1167. # Initialize maximum largest layer, to know which space to keep in memory
  1168. max_layer_size, max_layer_names = get_max_layer_size(modules_to_treat, module_sizes, no_split_module_classes)
  1169. # Ready ? This is going to be a bit messy.
  1170. while len(modules_to_treat) > 0:
  1171. name, module = modules_to_treat.pop(0)
  1172. if verbose:
  1173. print(f"\nTreating module {name}.")
  1174. # Max size in the remaining layers may have changed since we took one, so we maybe update it.
  1175. max_layer_names = [n for n in max_layer_names if n != name and not n.startswith(name + ".")]
  1176. if len(max_layer_names) == 0:
  1177. max_layer_size, max_layer_names = get_max_layer_size(
  1178. [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
  1179. module_sizes,
  1180. no_split_module_classes,
  1181. )
  1182. # Assess size needed
  1183. module_size = module_sizes[name]
  1184. # We keep relevant tied parameters only: one of the tied parameters in the group is inside the current module
  1185. # and the other is not.
  1186. # Note: If we are currently processing the name `compute.weight`, an other parameter named
  1187. # e.g. `compute.weight_submodule.parameter`
  1188. # needs to be considered outside the current module, hence the check with additional dots.
  1189. tied_param_groups = [
  1190. tied_group
  1191. for tied_group in tied_parameters
  1192. if any(name + "." in k + "." for k in tied_group) and not all(name + "." in k + "." for k in tied_group)
  1193. ]
  1194. if verbose and len(tied_param_groups) > 0:
  1195. print(f" Found the relevant tied param groups {tied_param_groups}")
  1196. # Then we keep track of all the parameters that are tied to the current module, but not in the current module
  1197. tied_params = sum(
  1198. [[p for p in tied_group if name + "." not in p + "."] for tied_group in tied_param_groups], []
  1199. )
  1200. if verbose and len(tied_params) > 0:
  1201. print(f" So those parameters need to be taken into account {tied_params}")
  1202. device = devices[current_device]
  1203. current_max_size = max_memory[device] if device != "disk" else None
  1204. current_memory_reserved = 0
  1205. # Reduce max size available by the largest layer.
  1206. if devices[current_device] in main_devices:
  1207. current_max_size = current_max_size - max_layer_size
  1208. current_memory_reserved = max_layer_size
  1209. module_size_with_ties, tied_module_names, tied_modules = get_module_size_with_ties(
  1210. tied_params, module_size, module_sizes, modules_to_treat
  1211. )
  1212. # The module and its tied modules fit on the current device.
  1213. if current_max_size is None or device_memory_used[device] + module_size_with_ties <= current_max_size:
  1214. if verbose:
  1215. output = f"Putting {name}"
  1216. if tied_module_names:
  1217. output += f" and {tied_module_names}"
  1218. else:
  1219. output += f" (size={module_size})"
  1220. if current_max_size is not None:
  1221. output += f" (available={current_max_size - device_memory_used[device]})"
  1222. output += f" on {device}."
  1223. print(output)
  1224. device_memory_used[device] += module_size_with_ties
  1225. # Assign the primary module to the device.
  1226. device_map[name] = device
  1227. # Assign tied modules if any.
  1228. for tied_module_name in tied_module_names:
  1229. if tied_module_name in [m[0] for m in modules_to_treat]:
  1230. # Find the index of the tied module in the list
  1231. tied_module_index = next(i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name)
  1232. # Remove the tied module from the list to prevent reprocessing
  1233. modules_to_treat.pop(tied_module_index)
  1234. # Assign the tied module to the device
  1235. device_map[tied_module_name] = device
  1236. # Buffer Handling
  1237. if not offload_buffers and isinstance(module, nn.Module):
  1238. # Compute the total buffer size for the module
  1239. current_buffer_size = compute_module_total_buffer_size(
  1240. module, dtype=dtype, special_dtypes=special_dtypes
  1241. )
  1242. # Update the buffer size on the device
  1243. device_buffer_sizes[device] = device_buffer_sizes.get(device, 0) + current_buffer_size
  1244. continue
  1245. # The current module itself fits, so we try to split the tied modules.
  1246. if len(tied_params) > 0 and device_memory_used[device] + module_size <= current_max_size:
  1247. # can we split one of the tied modules to make it smaller or do we need to go on the next device?
  1248. if verbose:
  1249. print(
  1250. f"Not enough space on {devices[current_device]} to put {name} and {tied_module_names} (space "
  1251. f"available {current_max_size - device_memory_used[device]}, needed size {module_size_with_ties})."
  1252. )
  1253. split_happened = False
  1254. for tied_module_name, tied_module in zip(tied_module_names, tied_modules):
  1255. tied_module_children = list(tied_module.named_children())
  1256. if len(tied_module_children) == 0 or tied_module.__class__.__name__ in no_split_module_classes:
  1257. # can't break this one.
  1258. continue
  1259. if verbose:
  1260. print(f"Splitting {tied_module_name}.")
  1261. tied_module_children = list(tied_module.named_parameters(recurse=False)) + tied_module_children
  1262. tied_module_children = [(f"{tied_module_name}.{n}", v) for n, v in tied_module_children]
  1263. tied_module_index = [i for i, (n, _) in enumerate(modules_to_treat) if n == tied_module_name][0]
  1264. modules_to_treat = (
  1265. [(name, module)]
  1266. + modules_to_treat[:tied_module_index]
  1267. + tied_module_children
  1268. + modules_to_treat[tied_module_index + 1 :]
  1269. )
  1270. # Update the max layer size.
  1271. max_layer_size, max_layer_names = get_max_layer_size(
  1272. [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
  1273. module_sizes,
  1274. no_split_module_classes,
  1275. )
  1276. split_happened = True
  1277. break
  1278. if split_happened:
  1279. continue
  1280. # If the tied module is not split, we go to the next device
  1281. if verbose:
  1282. print("None of the tied module can be split, going to the next device.")
  1283. # The current module itself doesn't fit, so we have to split it or go to the next device.
  1284. if device_memory_used[device] + module_size >= current_max_size:
  1285. # Split or not split?
  1286. modules_children = (
  1287. []
  1288. if isinstance(module, nn.Parameter) or isinstance(module, torch.Tensor)
  1289. else list(module.named_children())
  1290. )
  1291. if verbose:
  1292. print(
  1293. f"Not enough space on {devices[current_device]} to put {name} (space available "
  1294. f"{current_max_size - device_memory_used[device]}, module size {module_size})."
  1295. )
  1296. if len(modules_children) == 0 or module.__class__.__name__ in no_split_module_classes:
  1297. # -> no split, we go to the next device
  1298. if verbose:
  1299. print("This module cannot be split, going to the next device.")
  1300. else:
  1301. # -> split, we replace the module studied by its children + parameters
  1302. if verbose:
  1303. print(f"Splitting {name}.")
  1304. modules_children = list(module.named_parameters(recurse=False)) + modules_children
  1305. modules_to_treat = [(f"{name}.{n}", v) for n, v in modules_children] + modules_to_treat
  1306. # Update the max layer size.
  1307. max_layer_size, max_layer_names = get_max_layer_size(
  1308. [(n, m) for n, m in modules_to_treat if isinstance(m, torch.nn.Module)],
  1309. module_sizes,
  1310. no_split_module_classes,
  1311. )
  1312. continue
  1313. # If no module is assigned to the current device, we attempt to allocate a fallback module
  1314. # if fallback_allocation is enabled.
  1315. if device_memory_used[device] == 0 and fallback_allocation and device != "disk":
  1316. # We try to allocate a module that fits in the size limit using BFS.
  1317. # Recompute the current max size as we need to consider the current module as well.
  1318. current_max_size = max_memory[device] - max(max_layer_size, module_size_with_ties)
  1319. fallback_module_name, fallback_module, remaining_modules = fallback_allocate(
  1320. modules_to_treat,
  1321. module_sizes,
  1322. current_max_size - device_memory_used[device],
  1323. no_split_module_classes,
  1324. tied_parameters,
  1325. )
  1326. # use the next iteration to put the fallback module on the next device to avoid code duplication
  1327. if fallback_module is not None:
  1328. modules_to_treat = [(fallback_module_name, fallback_module)] + [(name, module)] + remaining_modules
  1329. continue
  1330. if device_memory_used[device] == 0:
  1331. device_minimum_assignment_memory[device] = module_size_with_ties + current_memory_reserved
  1332. # Neither the current module nor any tied modules can be split, so we move to the next device.
  1333. device_memory_used[device] = device_memory_used[device] + current_memory_reserved
  1334. current_device += 1
  1335. modules_to_treat = [(name, module)] + modules_to_treat
  1336. device_memory_used = {device: mem for device, mem in device_memory_used.items() if mem > 0}
  1337. if clean_result:
  1338. device_map = clean_device_map(device_map)
  1339. non_gpu_buffer_size = device_buffer_sizes.get("cpu", 0) + device_buffer_sizes.get("disk", 0)
  1340. if non_gpu_buffer_size > 0 and not offload_buffers:
  1341. is_buffer_fit_any_gpu = False
  1342. for gpu_device, gpu_max_memory in max_memory.items():
  1343. if gpu_device == "cpu" or gpu_device == "disk":
  1344. continue
  1345. if not is_buffer_fit_any_gpu:
  1346. gpu_memory_used = device_memory_used.get(gpu_device, 0)
  1347. if gpu_max_memory >= non_gpu_buffer_size + gpu_memory_used:
  1348. is_buffer_fit_any_gpu = True
  1349. if len(gpus) > 0 and not is_buffer_fit_any_gpu:
  1350. warnings.warn(
  1351. f"Current model requires {non_gpu_buffer_size} bytes of buffer for offloaded layers, which seems does "
  1352. f"not fit any GPU's remaining memory. If you are experiencing a OOM later, please consider using "
  1353. f"offload_buffers=True."
  1354. )
  1355. if device_minimum_assignment_memory:
  1356. devices_info = "\n".join(
  1357. f" - {device}: {mem} bytes required" for device, mem in device_minimum_assignment_memory.items()
  1358. )
  1359. logger.info(
  1360. f"Based on the current allocation process, no modules could be assigned to the following devices due to "
  1361. f"insufficient memory:\n"
  1362. f"{devices_info}\n"
  1363. f"These minimum requirements are specific to this allocation attempt and may vary. Consider increasing "
  1364. f"the available memory for these devices to at least the specified minimum, or adjusting the model config."
  1365. )
  1366. return device_map
  1367. def check_device_map(model: nn.Module, device_map: dict[str, Union[int, str, torch.device]]):
  1368. """
  1369. Checks a device map covers everything in a given model.
  1370. Args:
  1371. model (`torch.nn.Module`): The model to check the device map against.
  1372. device_map (`Dict[str, Union[int, str, torch.device]]`): The device map to check.
  1373. """
  1374. all_module_names = dict(model.named_modules())
  1375. invalid_keys = [k for k in device_map if k != "" and k not in all_module_names]
  1376. if invalid_keys:
  1377. warnings.warn(
  1378. f"The following device_map keys do not match any submodules in the model: {invalid_keys}", UserWarning
  1379. )
  1380. all_model_tensors = [name for name, _ in model.state_dict().items()]
  1381. for module_name in device_map.keys():
  1382. if module_name == "":
  1383. all_model_tensors.clear()
  1384. break
  1385. else:
  1386. all_model_tensors = [
  1387. name
  1388. for name in all_model_tensors
  1389. if not name == module_name and not name.startswith(module_name + ".")
  1390. ]
  1391. if len(all_model_tensors) > 0:
  1392. non_covered_params = ", ".join(all_model_tensors)
  1393. raise ValueError(
  1394. f"The device_map provided does not give any device for the following parameters: {non_covered_params}"
  1395. )
  1396. def load_state_dict(checkpoint_file, device_map=None):
  1397. """
  1398. Load a checkpoint from a given file. If the checkpoint is in the safetensors format and a device map is passed, the
  1399. weights can be fast-loaded directly on the GPU.
  1400. Args:
  1401. checkpoint_file (`str`): The path to the checkpoint to load.
  1402. device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
  1403. A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
  1404. name, once a given module name is inside, every submodule of it will be sent to the same device.
  1405. """
  1406. if checkpoint_file.endswith(".safetensors"):
  1407. with safe_open(checkpoint_file, framework="pt") as f:
  1408. metadata = f.metadata()
  1409. weight_names = f.keys()
  1410. if metadata is None:
  1411. logger.warning(
  1412. f"The safetensors archive passed at {checkpoint_file} does not contain metadata. "
  1413. "Make sure to save your model with the `save_pretrained` method. Defaulting to 'pt' metadata."
  1414. )
  1415. metadata = {"format": "pt"}
  1416. if metadata.get("format") not in ["pt", "tf", "flax"]:
  1417. raise OSError(
  1418. f"The safetensors archive passed at {checkpoint_file} does not contain the valid metadata. Make sure "
  1419. "you save your model with the `save_pretrained` method."
  1420. )
  1421. elif metadata["format"] != "pt":
  1422. raise ValueError(f"The checkpoint passed was saved with {metadata['format']}, we need a the pt format.")
  1423. if device_map is None:
  1424. return safe_load_file(checkpoint_file)
  1425. else:
  1426. # if we only have one device we can load everything directly
  1427. if len(set(device_map.values())) == 1:
  1428. device = list(device_map.values())[0]
  1429. target_device = device
  1430. if isinstance(device, int):
  1431. if is_npu_available():
  1432. target_device = f"npu:{device}"
  1433. elif is_hpu_available():
  1434. target_device = "hpu"
  1435. return safe_load_file(checkpoint_file, device=target_device)
  1436. devices = list(set(device_map.values()) - {"disk"})
  1437. # cpu device should always exist as fallback option
  1438. if "cpu" not in devices:
  1439. devices.append("cpu")
  1440. # For each device, get the weights that go there
  1441. device_weights = {device: [] for device in devices}
  1442. for module_name, device in device_map.items():
  1443. if device in devices:
  1444. device_weights[device].extend(
  1445. [k for k in weight_names if k == module_name or k.startswith(module_name + ".")]
  1446. )
  1447. # all weights that haven't defined a device should be loaded on CPU
  1448. device_weights["cpu"].extend([k for k in weight_names if k not in sum(device_weights.values(), [])])
  1449. tensors = {}
  1450. if is_tqdm_available():
  1451. progress_bar = tqdm(
  1452. main_process_only=False,
  1453. total=sum([len(device_weights[device]) for device in devices]),
  1454. unit="w",
  1455. smoothing=0,
  1456. leave=False,
  1457. )
  1458. else:
  1459. progress_bar = None
  1460. for device in devices:
  1461. target_device = device
  1462. if isinstance(device, int):
  1463. if is_npu_available():
  1464. target_device = f"npu:{device}"
  1465. elif is_hpu_available():
  1466. target_device = "hpu"
  1467. with safe_open(checkpoint_file, framework="pt", device=target_device) as f:
  1468. for key in device_weights[device]:
  1469. if progress_bar is not None:
  1470. progress_bar.set_postfix(dev=device, refresh=False)
  1471. progress_bar.set_description(key)
  1472. tensors[key] = f.get_tensor(key)
  1473. if progress_bar is not None:
  1474. progress_bar.update()
  1475. if progress_bar is not None:
  1476. progress_bar.close()
  1477. return tensors
  1478. else:
  1479. return torch.load(checkpoint_file, map_location=torch.device("cpu"), weights_only=True)
  1480. def get_state_dict_offloaded_model(model: nn.Module):
  1481. """
  1482. Returns the state dictionary for an offloaded model via iterative onloading
  1483. Args:
  1484. model (`torch.nn.Module`):
  1485. The offloaded model we want to save
  1486. """
  1487. state_dict = {}
  1488. placeholders = set()
  1489. for name, module in model.named_modules():
  1490. if name == "":
  1491. continue
  1492. try:
  1493. with align_module_device(module, "cpu"):
  1494. module_state_dict = module.state_dict()
  1495. except MemoryError:
  1496. raise MemoryError("Offloaded module must fit in CPU memory to call save_model!") from None
  1497. for key in module_state_dict:
  1498. # ignore placeholder parameters that are still on the meta device
  1499. if module_state_dict[key].device == torch.device("meta"):
  1500. placeholders.add(name + f".{key}")
  1501. continue
  1502. params = module_state_dict[key]
  1503. state_dict[name + f".{key}"] = params.to("cpu") # move buffers to cpu
  1504. for key in placeholders.copy():
  1505. if key in state_dict:
  1506. placeholders.remove(key)
  1507. if placeholders:
  1508. logger.warning(f"The following tensors were not saved because they were still on meta device: {placeholders}")
  1509. return state_dict
  1510. def get_state_dict_from_offload(
  1511. module: nn.Module,
  1512. module_name: str,
  1513. state_dict: dict[str, Union[str, torch.tensor]],
  1514. device_to_put_offload: Union[int, str, torch.device] = "cpu",
  1515. ):
  1516. """
  1517. Retrieve the state dictionary (with parameters) from an offloaded module and load into a specified device (defaults
  1518. to cpu).
  1519. Args:
  1520. module: (`torch.nn.Module`):
  1521. The module we want to retrieve a state dictionary from
  1522. module_name: (`str`):
  1523. The name of the module of interest
  1524. state_dict (`Dict[str, Union[int, str, torch.device]]`):
  1525. Dictionary of {module names: parameters}
  1526. device_to_put_offload (`Union[int, str, torch.device]`):
  1527. Device to load offloaded parameters into, defaults to the cpu.
  1528. """
  1529. root = module_name[: module_name.rfind(".")] # module name without .weight or .bias
  1530. # do not move parameters if the module is not offloaded
  1531. if not has_offloaded_params(module):
  1532. device_to_put_offload = None
  1533. # assign the device to which the offloaded parameters will be sent
  1534. with align_module_device(module, device_to_put_offload):
  1535. for m_key, params in module.state_dict().items():
  1536. if (root + f".{m_key}") in state_dict:
  1537. state_dict[root + f".{m_key}"] = params
  1538. return state_dict
  1539. def load_checkpoint_in_model(
  1540. model: nn.Module,
  1541. checkpoint: Union[str, os.PathLike],
  1542. device_map: Optional[dict[str, Union[int, str, torch.device]]] = None,
  1543. offload_folder: Optional[Union[str, os.PathLike]] = None,
  1544. dtype: Optional[Union[str, torch.dtype]] = None,
  1545. offload_state_dict: bool = False,
  1546. offload_buffers: bool = False,
  1547. keep_in_fp32_modules: Optional[list[str]] = None,
  1548. offload_8bit_bnb: bool = False,
  1549. strict: bool = False,
  1550. full_state_dict: bool = True,
  1551. broadcast_from_rank0: bool = False,
  1552. ):
  1553. """
  1554. Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
  1555. loaded.
  1556. <Tip warning={true}>
  1557. Once loaded across devices, you still need to call [`dispatch_model`] on your model to make it able to run. To
  1558. group the checkpoint loading and dispatch in one single call, use [`load_checkpoint_and_dispatch`].
  1559. </Tip>
  1560. Args:
  1561. model (`torch.nn.Module`):
  1562. The model in which we want to load a checkpoint.
  1563. checkpoint (`str` or `os.PathLike`):
  1564. The folder checkpoint to load. It can be:
  1565. - a path to a file containing a whole model state dict
  1566. - a path to a `.json` file containing the index to a sharded checkpoint
  1567. - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
  1568. - a path to a folder containing a unique pytorch_model.bin or a model.safetensors file.
  1569. device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
  1570. A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
  1571. name, once a given module name is inside, every submodule of it will be sent to the same device.
  1572. offload_folder (`str` or `os.PathLike`, *optional*):
  1573. If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
  1574. dtype (`str` or `torch.dtype`, *optional*):
  1575. If provided, the weights will be converted to that type when loaded.
  1576. offload_state_dict (`bool`, *optional*, defaults to `False`):
  1577. If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
  1578. the weight of the CPU state dict + the biggest shard does not fit.
  1579. offload_buffers (`bool`, *optional*, defaults to `False`):
  1580. Whether or not to include the buffers in the weights offloaded to disk.
  1581. keep_in_fp32_modules(`List[str]`, *optional*):
  1582. A list of the modules that we keep in `torch.float32` dtype.
  1583. offload_8bit_bnb (`bool`, *optional*):
  1584. Whether or not to enable offload of 8-bit modules on cpu/disk.
  1585. strict (`bool`, *optional*, defaults to `False`):
  1586. Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
  1587. state_dict.
  1588. full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the
  1589. loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict.
  1590. broadcast_from_rank0 (`False`, *optional*, defaults to `False`): when the option is `True`, a distributed
  1591. `ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors
  1592. in the state_dict one by one to other ranks. Other ranks will receive the tensors and shard (if applicable)
  1593. according to the local shards in the model.
  1594. """
  1595. if offload_8bit_bnb:
  1596. from .bnb import quantize_and_offload_8bit
  1597. tied_params = find_tied_parameters(model)
  1598. if check_tied_parameters_in_config(model) and len(tied_params) == 0:
  1599. logger.warning(
  1600. "The model weights are not tied. Please use the `tie_weights` method before using the `infer_auto_device` function."
  1601. )
  1602. if device_map is not None:
  1603. check_tied_parameters_on_same_device(tied_params, device_map)
  1604. if offload_folder is None and device_map is not None and "disk" in device_map.values():
  1605. raise ValueError(
  1606. "At least one of the model submodule will be offloaded to disk, please pass along an `offload_folder`."
  1607. )
  1608. elif offload_folder is not None and device_map is not None and "disk" in device_map.values():
  1609. os.makedirs(offload_folder, exist_ok=True)
  1610. if isinstance(dtype, str):
  1611. # We accept "torch.float16" or just "float16"
  1612. dtype = dtype.replace("torch.", "")
  1613. dtype = getattr(torch, dtype)
  1614. checkpoint_files = None
  1615. index_filename = None
  1616. if os.path.isfile(checkpoint):
  1617. if str(checkpoint).endswith(".json"):
  1618. index_filename = checkpoint
  1619. else:
  1620. checkpoint_files = [checkpoint]
  1621. elif os.path.isdir(checkpoint):
  1622. # check if the whole state dict is present
  1623. potential_state_bin = [f for f in os.listdir(checkpoint) if f == WEIGHTS_NAME]
  1624. potential_state_safetensor = [f for f in os.listdir(checkpoint) if f == SAFE_WEIGHTS_NAME]
  1625. if len(potential_state_bin) == 1:
  1626. checkpoint_files = [os.path.join(checkpoint, potential_state_bin[0])]
  1627. elif len(potential_state_safetensor) == 1:
  1628. checkpoint_files = [os.path.join(checkpoint, potential_state_safetensor[0])]
  1629. else:
  1630. # otherwise check for sharded checkpoints
  1631. potential_index = [f for f in os.listdir(checkpoint) if f.endswith(".index.json")]
  1632. if len(potential_index) == 0:
  1633. raise ValueError(
  1634. f"{checkpoint} is not a folder containing a `.index.json` file or a {WEIGHTS_NAME} or a {SAFE_WEIGHTS_NAME} file"
  1635. )
  1636. elif len(potential_index) == 1:
  1637. index_filename = os.path.join(checkpoint, potential_index[0])
  1638. else:
  1639. raise ValueError(
  1640. f"{checkpoint} containing more than one `.index.json` file, delete the irrelevant ones."
  1641. )
  1642. else:
  1643. raise ValueError(
  1644. "`checkpoint` should be the path to a file containing a whole state dict, or the index of a sharded "
  1645. f"checkpoint, or a folder containing a sharded checkpoint or the whole state dict, but got {checkpoint}."
  1646. )
  1647. if index_filename is not None:
  1648. checkpoint_folder = os.path.split(index_filename)[0]
  1649. with open(index_filename) as f:
  1650. index = json.loads(f.read())
  1651. if "weight_map" in index:
  1652. index = index["weight_map"]
  1653. checkpoint_files = sorted(list(set(index.values())))
  1654. checkpoint_files = [os.path.join(checkpoint_folder, f) for f in checkpoint_files]
  1655. # Logic for missing/unexepected keys goes here.
  1656. offload_index = {}
  1657. if offload_state_dict:
  1658. state_dict_folder = tempfile.mkdtemp()
  1659. state_dict_index = {}
  1660. unexpected_keys = set()
  1661. model_keys = set(model.state_dict().keys())
  1662. buffer_names = [name for name, _ in model.named_buffers()]
  1663. model_devices = {t.device for t in model.state_dict().values() if isinstance(t, torch.Tensor)}
  1664. model_physical_devices = model_devices - {torch.device("meta")}
  1665. for checkpoint_file in checkpoint_files:
  1666. if device_map is None:
  1667. # exception for multi-device loading was made for the meta device in torch v2.7.0
  1668. # https://github.com/pytorch/pytorch/blob/v2.6.0/torch/distributed/checkpoint/state_dict.py#L557-L563
  1669. # https://github.com/pytorch/pytorch/blob/v2.7.0-rc2/torch/distributed/checkpoint/state_dict.py#L575-L587
  1670. if is_torch_version(">=", "2.2.0") and (
  1671. (is_torch_version(">=", "2.7.0") and len(model_physical_devices) <= 1) or len(model_devices) <= 1
  1672. ):
  1673. from torch.distributed.checkpoint.state_dict import StateDictOptions, set_model_state_dict
  1674. broadcast_from_rank0 &= is_torch_version(">=", "2.4.0")
  1675. loaded_checkpoint = (
  1676. load_state_dict(checkpoint_file, device_map=device_map)
  1677. if not broadcast_from_rank0 or dist.get_rank() == 0
  1678. else {}
  1679. )
  1680. set_model_state_dict(
  1681. model,
  1682. loaded_checkpoint,
  1683. options=StateDictOptions(
  1684. full_state_dict=full_state_dict,
  1685. strict=strict,
  1686. **({"broadcast_from_rank0": broadcast_from_rank0} if is_torch_version(">=", "2.4.0") else {}),
  1687. ),
  1688. )
  1689. else:
  1690. loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
  1691. model.load_state_dict(loaded_checkpoint, strict=strict)
  1692. unexpected_keys.update(set(loaded_checkpoint.keys()) - model_keys)
  1693. else:
  1694. loaded_checkpoint = load_state_dict(checkpoint_file, device_map=device_map)
  1695. for param_name, param in loaded_checkpoint.items():
  1696. # skip SCB parameter (for 8-bit serialization)
  1697. if "SCB" in param_name:
  1698. continue
  1699. if param_name not in model_keys:
  1700. unexpected_keys.add(param_name)
  1701. if not strict:
  1702. continue # Skip loading this parameter.
  1703. module_name = param_name
  1704. while len(module_name) > 0 and module_name not in device_map:
  1705. module_name = ".".join(module_name.split(".")[:-1])
  1706. if module_name == "" and "" not in device_map:
  1707. # TODO: group all errors and raise at the end.
  1708. raise ValueError(f"{param_name} doesn't have any device set.")
  1709. param_device = device_map[module_name]
  1710. new_dtype = dtype
  1711. if dtype is not None and torch.is_floating_point(param):
  1712. if keep_in_fp32_modules is not None and dtype == torch.float16:
  1713. proceed = False
  1714. for key in keep_in_fp32_modules:
  1715. if ((key in param_name) and (key + "." in param_name)) or key == param_name:
  1716. proceed = True
  1717. break
  1718. if proceed:
  1719. new_dtype = torch.float32
  1720. if "weight" in param_name and param_name.replace("weight", "SCB") in loaded_checkpoint.keys():
  1721. if param.dtype == torch.int8:
  1722. fp16_statistics = loaded_checkpoint[param_name.replace("weight", "SCB")]
  1723. else:
  1724. fp16_statistics = None
  1725. if param_device == "disk":
  1726. if offload_buffers or param_name not in buffer_names:
  1727. if new_dtype is None:
  1728. new_dtype = param.dtype
  1729. if offload_8bit_bnb:
  1730. quantize_and_offload_8bit(
  1731. model, param, param_name, new_dtype, offload_folder, offload_index, fp16_statistics
  1732. )
  1733. continue
  1734. else:
  1735. set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
  1736. offload_weight(param, param_name, offload_folder, index=offload_index)
  1737. elif param_device == "cpu" and offload_state_dict:
  1738. if new_dtype is None:
  1739. new_dtype = param.dtype
  1740. if offload_8bit_bnb:
  1741. quantize_and_offload_8bit(
  1742. model, param, param_name, new_dtype, state_dict_folder, state_dict_index, fp16_statistics
  1743. )
  1744. else:
  1745. set_module_tensor_to_device(model, param_name, "meta", dtype=new_dtype)
  1746. offload_weight(param, param_name, state_dict_folder, index=state_dict_index)
  1747. else:
  1748. set_module_tensor_to_device(
  1749. model,
  1750. param_name,
  1751. param_device,
  1752. value=param,
  1753. dtype=new_dtype,
  1754. fp16_statistics=fp16_statistics,
  1755. )
  1756. # Force Python to clean up.
  1757. del loaded_checkpoint
  1758. gc.collect()
  1759. if not strict and len(unexpected_keys) > 0:
  1760. logger.warning(
  1761. f"Some weights of the model checkpoint at {checkpoint} were not used when"
  1762. f" initializing {model.__class__.__name__}: {unexpected_keys}. This may or may not be an issue - make sure that the checkpoint does not have unnecessary parameters, or that the model definition correctly corresponds to the checkpoint."
  1763. )
  1764. save_offload_index(offload_index, offload_folder)
  1765. # Load back offloaded state dict on CPU
  1766. if offload_state_dict:
  1767. load_offloaded_weights(model, state_dict_index, state_dict_folder)
  1768. shutil.rmtree(state_dict_folder)
  1769. retie_parameters(model, tied_params)
  1770. def get_mixed_precision_context_manager(native_amp: bool = False, autocast_kwargs: AutocastKwargs = None):
  1771. """
  1772. Return a context manager for autocasting mixed precision
  1773. Args:
  1774. native_amp (`bool`, *optional*, defaults to False):
  1775. Whether mixed precision is actually enabled.
  1776. cache_enabled (`bool`, *optional*, defaults to True):
  1777. Whether the weight cache inside autocast should be enabled.
  1778. """
  1779. state = AcceleratorState()
  1780. if autocast_kwargs is None:
  1781. autocast_kwargs = {}
  1782. else:
  1783. autocast_kwargs = autocast_kwargs.to_kwargs()
  1784. if native_amp:
  1785. device_type = (
  1786. "cuda"
  1787. if (state.distributed_type == DistributedType.XLA and is_torch_xla_available(check_is_gpu=True))
  1788. else state.device.type
  1789. )
  1790. if state.mixed_precision == "fp16":
  1791. return torch.autocast(device_type=device_type, dtype=torch.float16, **autocast_kwargs)
  1792. elif state.mixed_precision in ["bf16", "fp8"] and state.distributed_type in [
  1793. DistributedType.NO,
  1794. DistributedType.MULTI_CPU,
  1795. DistributedType.MULTI_GPU,
  1796. DistributedType.MULTI_MLU,
  1797. DistributedType.MULTI_SDAA,
  1798. DistributedType.MULTI_MUSA,
  1799. DistributedType.MULTI_NPU,
  1800. DistributedType.MULTI_XPU,
  1801. DistributedType.MULTI_HPU,
  1802. DistributedType.FSDP,
  1803. DistributedType.XLA,
  1804. ]:
  1805. return torch.autocast(device_type=device_type, dtype=torch.bfloat16, **autocast_kwargs)
  1806. else:
  1807. return torch.autocast(device_type=device_type, **autocast_kwargs)
  1808. else:
  1809. return contextlib.nullcontext()
  def get_grad_scaler(distributed_type: DistributedType = None, **kwargs):
      """
      Build the `GradScaler` matching the current distributed setup and accelerator backend.

      Cases are checked in order and the first match wins: FSDP, XLA on GPU, MLU, SDAA, MUSA, NPU, HPU, XPU, MPS,
      and finally CUDA as the fallback.

      Args:
          distributed_type (`DistributedType`, *optional*, defaults to None):
              The distributed environment in use; `DistributedType.FSDP` selects the sharded scaler.
          kwargs:
              Forwarded unchanged to the selected `GradScaler` constructor.

      Raises:
          ValueError: If an MPS device is detected but PyTorch is older than 2.8.0.
      """
      # FSDP takes priority over every device-specific scaler.
      if distributed_type == DistributedType.FSDP:
          from torch.distributed.fsdp.sharded_grad_scaler import ShardedGradScaler

          return ShardedGradScaler(**kwargs)

      # Backends that ship their own backend-specific GradScaler class.
      if is_torch_xla_available(check_is_gpu=True):
          import torch_xla.amp as xamp

          return xamp.GradScaler(**kwargs)
      if is_mlu_available():
          return torch.mlu.amp.GradScaler(**kwargs)
      if is_sdaa_available():
          return torch.sdaa.amp.GradScaler(**kwargs)
      if is_musa_available():
          return torch.musa.amp.GradScaler(**kwargs)
      if is_npu_available():
          return torch.npu.amp.GradScaler(**kwargs)

      # Backends served by the device-generic `torch.amp.GradScaler(device, ...)`.
      if is_hpu_available():
          return torch.amp.GradScaler("hpu", **kwargs)
      if is_xpu_available():
          return torch.amp.GradScaler("xpu", **kwargs)
      if is_mps_available():
          if not is_torch_version(">=", "2.8.0"):
              raise ValueError("Grad Scaler with MPS device requires a Pytorch >= 2.8.0")
          return torch.amp.GradScaler("mps", **kwargs)

      # CUDA fallback; on PyTorch < 2.3 the CUDA-specific class is used instead of the generic one.
      if is_torch_version(">=", "2.3"):
          return torch.amp.GradScaler("cuda", **kwargs)
      return torch.cuda.amp.GradScaler(**kwargs)
  def has_offloaded_params(module: torch.nn.Module) -> bool:
      """
      Report whether `module` carries an `AlignDevicesHook` (stored as `module._hf_hook`) with offloading enabled.

      Args:
          module (`torch.nn.Module`): The module to inspect for an offload hook.

      Returns:
          bool: `True` if an `AlignDevicesHook` is attached and its `offload` flag is set, `False` otherwise.
      """
      # Imported locally to avoid a circular import with `..hooks`.
      from ..hooks import AlignDevicesHook  # avoid circular import

      attached_hook = getattr(module, "_hf_hook", None)
      if not isinstance(attached_hook, AlignDevicesHook):
          return False
      return attached_hook.offload
  @contextlib.contextmanager
  def align_module_device(module: torch.nn.Module, execution_device: Optional[torch.device] = None):
      """
      Context manager that moves a module's parameters to the specified execution device.

      For modules with an offloading `AlignDevicesHook`, the hook's `pre_forward` is run on entry and its
      `post_forward` on exit. For other modules, the module's direct parameters (not those of submodules) are moved
      to `execution_device` and restored to their original devices on exit.

      Args:
          module (`torch.nn.Module`):
              Module with parameters to align.
          execution_device (`torch.device`, *optional*):
              If provided, overrides the module's execution device within the context. If `None`, an offloaded
              module keeps its hook's configured execution device, and a non-offloaded module is left untouched.
      """
      if has_offloaded_params(module):
          if execution_device is not None:
              original_device = module._hf_hook.execution_device
              module._hf_hook.execution_device = execution_device

          try:
              module._hf_hook.pre_forward(module)
              yield
          finally:
              # Restore the overridden device even if `post_forward` raises, so the hook is never left pointing at
              # the temporary execution device.
              try:
                  module._hf_hook.post_forward(module, None)
              finally:
                  if execution_device is not None:
                      module._hf_hook.execution_device = original_device

      elif execution_device is not None:
          # Remember each direct parameter's device so it can be put back on exit.
          devices = {name: param.device for name, param in module.named_parameters(recurse=False)}
          try:
              for name in devices:
                  set_module_tensor_to_device(module, name, execution_device)
              yield
          finally:
              for name, device in devices.items():
                  set_module_tensor_to_device(module, name, device)

      else:
          yield