big_modeling.py 36 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import logging
  15. import os
  16. import re
  17. from contextlib import contextmanager
  18. from functools import wraps
  19. from typing import Optional, Union
  20. import torch
  21. import torch.nn as nn
  22. from .hooks import (
  23. AlignDevicesHook,
  24. CpuOffload,
  25. LayerwiseCastingHook,
  26. UserCpuOffloadHook,
  27. add_hook_to_module,
  28. attach_align_device_hook,
  29. attach_align_device_hook_on_blocks,
  30. )
  31. from .utils import (
  32. OffloadedWeightsLoader,
  33. check_cuda_p2p_ib_support,
  34. check_device_map,
  35. extract_submodules_state_dict,
  36. find_tied_parameters,
  37. get_balanced_memory,
  38. infer_auto_device_map,
  39. is_bnb_available,
  40. is_mlu_available,
  41. is_musa_available,
  42. is_npu_available,
  43. is_sdaa_available,
  44. is_xpu_available,
  45. load_checkpoint_in_model,
  46. offload_state_dict,
  47. parse_flag_from_env,
  48. retie_parameters,
  49. )
  50. from .utils.constants import SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING
  51. from .utils.other import recursive_getattr
  52. logger = logging.getLogger(__name__)
  53. @contextmanager
  54. def init_empty_weights(include_buffers: Optional[bool] = None):
  55. """
  56. A context manager under which models are initialized with all parameters on the meta device, therefore creating an
  57. empty model. Useful when just initializing the model would blow the available RAM.
  58. Args:
  59. include_buffers (`bool`, *optional*):
  60. Whether or not to also put all buffers on the meta device while initializing.
  61. Example:
  62. ```python
  63. import torch.nn as nn
  64. from accelerate import init_empty_weights
  65. # Initialize a model with 100 billions parameters in no time and without using any RAM.
  66. with init_empty_weights():
  67. tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
  68. ```
  69. <Tip warning={true}>
  70. Any model created under this context manager has no weights. As such you can't do something like
  71. `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
  72. Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
  73. called.
  74. </Tip>
  75. """
  76. if include_buffers is None:
  77. include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
  78. with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
  79. yield f
  80. @contextmanager
  81. def init_on_device(device: torch.device, include_buffers: Optional[bool] = None):
  82. """
  83. A context manager under which models are initialized with all parameters on the specified device.
  84. Args:
  85. device (`torch.device`):
  86. Device to initialize all parameters on.
  87. include_buffers (`bool`, *optional*):
  88. Whether or not to also put all buffers on the meta device while initializing.
  89. Example:
  90. ```python
  91. import torch.nn as nn
  92. from accelerate import init_on_device
  93. with init_on_device(device=torch.device("cuda")):
  94. tst = nn.Linear(100, 100) # on `cuda` device
  95. ```
  96. """
  97. if include_buffers is None:
  98. include_buffers = parse_flag_from_env("ACCELERATE_INIT_INCLUDE_BUFFERS", False)
  99. if include_buffers:
  100. with device:
  101. yield
  102. return
  103. old_register_parameter = nn.Module.register_parameter
  104. if include_buffers:
  105. old_register_buffer = nn.Module.register_buffer
  106. def register_empty_parameter(module, name, param):
  107. old_register_parameter(module, name, param)
  108. if param is not None:
  109. param_cls = type(module._parameters[name])
  110. kwargs = module._parameters[name].__dict__
  111. kwargs["requires_grad"] = param.requires_grad
  112. module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
  113. def register_empty_buffer(module, name, buffer, persistent=True):
  114. old_register_buffer(module, name, buffer, persistent=persistent)
  115. if buffer is not None:
  116. module._buffers[name] = module._buffers[name].to(device)
  117. # Patch tensor creation
  118. if include_buffers:
  119. tensor_constructors_to_patch = {
  120. torch_function_name: getattr(torch, torch_function_name)
  121. for torch_function_name in ["empty", "zeros", "ones", "full"]
  122. }
  123. else:
  124. tensor_constructors_to_patch = {}
  125. def patch_tensor_constructor(fn):
  126. def wrapper(*args, **kwargs):
  127. kwargs["device"] = device
  128. return fn(*args, **kwargs)
  129. return wrapper
  130. try:
  131. nn.Module.register_parameter = register_empty_parameter
  132. if include_buffers:
  133. nn.Module.register_buffer = register_empty_buffer
  134. for torch_function_name in tensor_constructors_to_patch.keys():
  135. setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
  136. yield
  137. finally:
  138. nn.Module.register_parameter = old_register_parameter
  139. if include_buffers:
  140. nn.Module.register_buffer = old_register_buffer
  141. for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
  142. setattr(torch, torch_function_name, old_torch_function)
  143. def cpu_offload(
  144. model: nn.Module,
  145. execution_device: Optional[torch.device] = None,
  146. offload_buffers: bool = False,
  147. state_dict: Optional[dict[str, torch.Tensor]] = None,
  148. preload_module_classes: Optional[list[str]] = None,
  149. ):
  150. """
  151. Activates full CPU offload for a model. As a result, all parameters of the model will be offloaded and only one
  152. copy of the state dict of the model will be kept. During the forward pass, parameters will be extracted from that
  153. state dict and put on the execution device passed as they are needed, then offloaded again.
  154. Args:
  155. model (`torch.nn.Module`):
  156. The model to offload.
  157. execution_device (`torch.device`, *optional*):
  158. The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
  159. model first parameter device.
  160. offload_buffers (`bool`, *optional*, defaults to `False`):
  161. Whether or not to offload the buffers with the model parameters.
  162. state_dict (`Dict[str, torch.Tensor]`, *optional*):
  163. The state dict of the model that will be kept on CPU.
  164. preload_module_classes (`List[str]`, *optional*):
  165. A list of classes whose instances should load all their weights (even in the submodules) at the beginning
  166. of the forward. This should only be used for classes that have submodules which are registered but not
  167. called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
  168. `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
  169. """
  170. if execution_device is None:
  171. execution_device = next(iter(model.parameters())).device
  172. if state_dict is None:
  173. state_dict = {n: p.to("cpu") for n, p in model.state_dict().items()}
  174. add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
  175. attach_align_device_hook(
  176. model,
  177. execution_device=execution_device,
  178. offload=True,
  179. offload_buffers=offload_buffers,
  180. weights_map=state_dict,
  181. preload_module_classes=preload_module_classes,
  182. )
  183. return model
  184. def cpu_offload_with_hook(
  185. model: torch.nn.Module,
  186. execution_device: Optional[Union[int, str, torch.device]] = None,
  187. prev_module_hook: Optional[UserCpuOffloadHook] = None,
  188. ):
  189. """
  190. Offloads a model on the CPU and puts it back to an execution device when executed. The difference with
  191. [`cpu_offload`] is that the model stays on the execution device after the forward and is only offloaded again when
  192. the `offload` method of the returned `hook` is called. Useful for pipelines running a model in a loop.
  193. Args:
  194. model (`torch.nn.Module`):
  195. The model to offload.
  196. execution_device(`str`, `int` or `torch.device`, *optional*):
  197. The device on which the model should be executed. Will default to the MPS device if it's available, then
  198. GPU 0 if there is a GPU, and finally to the CPU.
  199. prev_module_hook (`UserCpuOffloadHook`, *optional*):
  200. The hook sent back by this function for a previous model in the pipeline you are running. If passed, its
  201. offload method will be called just before the forward of the model to which this hook is attached.
  202. Example:
  203. ```py
  204. model_1, hook_1 = cpu_offload_with_hook(model_1, cuda_device)
  205. model_2, hook_2 = cpu_offload_with_hook(model_2, cuda_device, prev_module_hook=hook_1)
  206. model_3, hook_3 = cpu_offload_with_hook(model_3, cuda_device, prev_module_hook=hook_2)
  207. hid_1 = model_1(input)
  208. for i in range(50):
  209. # model1 is offloaded on the CPU at the first iteration, model 2 stays on the GPU for this whole loop.
  210. hid_2 = model_2(hid_1)
  211. # model2 is offloaded to the CPU just before this forward.
  212. hid_3 = model_3(hid_3)
  213. # For model3, you need to manually call the hook offload method.
  214. hook_3.offload()
  215. ```
  216. """
  217. hook = CpuOffload(execution_device=execution_device, prev_module_hook=prev_module_hook)
  218. add_hook_to_module(model, hook, append=True)
  219. user_hook = UserCpuOffloadHook(model, hook)
  220. return model, user_hook
  221. def disk_offload(
  222. model: nn.Module,
  223. offload_dir: Union[str, os.PathLike],
  224. execution_device: Optional[torch.device] = None,
  225. offload_buffers: bool = False,
  226. preload_module_classes: Optional[list[str]] = None,
  227. ):
  228. """
  229. Activates full disk offload for a model. As a result, all parameters of the model will be offloaded as
  230. memory-mapped array in a given folder. During the forward pass, parameters will be accessed from that folder and
  231. put on the execution device passed as they are needed, then offloaded again.
  232. Args:
  233. model (`torch.nn.Module`): The model to offload.
  234. offload_dir (`str` or `os.PathLike`):
  235. The folder in which to offload the model weights (or where the model weights are already offloaded).
  236. execution_device (`torch.device`, *optional*):
  237. The device on which the forward pass of the model will be executed (should be a GPU). Will default to the
  238. model's first parameter device.
  239. offload_buffers (`bool`, *optional*, defaults to `False`):
  240. Whether or not to offload the buffers with the model parameters.
  241. preload_module_classes (`List[str]`, *optional*):
  242. A list of classes whose instances should load all their weights (even in the submodules) at the beginning
  243. of the forward. This should only be used for classes that have submodules which are registered but not
  244. called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
  245. `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
  246. """
  247. if not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")):
  248. offload_state_dict(offload_dir, model.state_dict())
  249. if execution_device is None:
  250. execution_device = next(iter(model.parameters())).device
  251. weights_map = OffloadedWeightsLoader(save_folder=offload_dir)
  252. add_hook_to_module(model, AlignDevicesHook(io_same_device=True), append=True)
  253. attach_align_device_hook(
  254. model,
  255. execution_device=execution_device,
  256. offload=True,
  257. offload_buffers=offload_buffers,
  258. weights_map=weights_map,
  259. preload_module_classes=preload_module_classes,
  260. )
  261. return model
  262. def dispatch_model(
  263. model: nn.Module,
  264. device_map: dict[str, Union[str, int, torch.device]],
  265. main_device: Optional[torch.device] = None,
  266. state_dict: Optional[dict[str, torch.Tensor]] = None,
  267. offload_dir: Optional[Union[str, os.PathLike]] = None,
  268. offload_index: Optional[dict[str, str]] = None,
  269. offload_buffers: bool = False,
  270. skip_keys: Optional[Union[str, list[str]]] = None,
  271. preload_module_classes: Optional[list[str]] = None,
  272. force_hooks: bool = False,
  273. ):
  274. """
  275. Dispatches a model according to a given device map. Layers of the model might be spread across GPUs, offloaded on
  276. the CPU or even the disk.
  277. Args:
  278. model (`torch.nn.Module`):
  279. The model to dispatch.
  280. device_map (`Dict[str, Union[str, int, torch.device]]`):
  281. A dictionary mapping module names in the models `state_dict` to the device they should go to. Note that
  282. `"disk"` is accepted even if it's not a proper value for `torch.device`.
  283. main_device (`str`, `int` or `torch.device`, *optional*):
  284. The main execution device. Will default to the first device in the `device_map` different from `"cpu"` or
  285. `"disk"`.
  286. state_dict (`Dict[str, torch.Tensor]`, *optional*):
  287. The state dict of the part of the model that will be kept on CPU.
  288. offload_dir (`str` or `os.PathLike`):
  289. The folder in which to offload the model weights (or where the model weights are already offloaded).
  290. offload_index (`Dict`, *optional*):
  291. A dictionary from weight name to their information (`dtype`/ `shape` or safetensors filename). Will default
  292. to the index saved in `save_folder`.
  293. offload_buffers (`bool`, *optional*, defaults to `False`):
  294. Whether or not to offload the buffers with the model parameters.
  295. skip_keys (`str` or `List[str]`, *optional*):
  296. A list of keys to ignore when moving inputs or outputs between devices.
  297. preload_module_classes (`List[str]`, *optional*):
  298. A list of classes whose instances should load all their weights (even in the submodules) at the beginning
  299. of the forward. This should only be used for classes that have submodules which are registered but not
  300. called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
  301. `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
  302. force_hooks (`bool`, *optional*, defaults to `False`):
  303. Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
  304. single device.
  305. """
  306. # Error early if the device map is incomplete.
  307. check_device_map(model, device_map)
  308. # We need to force hook for quantized model that can't be moved with to()
  309. if getattr(model, "quantization_method", "bitsandbytes") == "bitsandbytes":
  310. # since bnb 0.43.2, we can move 4-bit model
  311. if getattr(model, "is_loaded_in_8bit", False) or (
  312. getattr(model, "is_loaded_in_4bit", False) and not is_bnb_available(min_version="0.43.2")
  313. ):
  314. force_hooks = True
  315. # We attach hooks if the device_map has at least 2 different devices or if
  316. # force_hooks is set to `True`. Otherwise, the model in already loaded
  317. # in the unique device and the user can decide where to dispatch the model.
  318. # If the model is quantized, we always force-dispatch the model
  319. if (len(set(device_map.values())) > 1) or force_hooks:
  320. if main_device is None:
  321. if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
  322. main_device = "cpu"
  323. else:
  324. main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]
  325. if main_device != "cpu":
  326. cpu_modules = [name for name, device in device_map.items() if device == "cpu"]
  327. if state_dict is None and len(cpu_modules) > 0:
  328. state_dict = extract_submodules_state_dict(model.state_dict(), cpu_modules)
  329. disk_modules = [name for name, device in device_map.items() if device == "disk"]
  330. if offload_dir is None and offload_index is None and len(disk_modules) > 0:
  331. raise ValueError(
  332. "We need an `offload_dir` to dispatch this model according to this `device_map`, the following submodules "
  333. f"need to be offloaded: {', '.join(disk_modules)}."
  334. )
  335. if (
  336. len(disk_modules) > 0
  337. and offload_index is None
  338. and (not os.path.isdir(offload_dir) or not os.path.isfile(os.path.join(offload_dir, "index.json")))
  339. ):
  340. disk_state_dict = extract_submodules_state_dict(model.state_dict(), disk_modules)
  341. offload_state_dict(offload_dir, disk_state_dict)
  342. execution_device = {
  343. name: main_device if device in ["cpu", "disk"] else device for name, device in device_map.items()
  344. }
  345. execution_device[""] = main_device
  346. offloaded_devices = ["disk"] if main_device == "cpu" or main_device == "mps" else ["cpu", "disk"]
  347. offload = {name: device in offloaded_devices for name, device in device_map.items()}
  348. save_folder = offload_dir if len(disk_modules) > 0 else None
  349. if state_dict is not None or save_folder is not None or offload_index is not None:
  350. device = main_device if offload_index is not None else None
  351. weights_map = OffloadedWeightsLoader(
  352. state_dict=state_dict, save_folder=save_folder, index=offload_index, device=device
  353. )
  354. else:
  355. weights_map = None
  356. # When dispatching the model's parameters to the devices specified in device_map, we want to avoid allocating memory several times for the
  357. # tied parameters. The dictionary tied_params_map keeps track of the already allocated data for a given tied parameter (represented by its
  358. # original pointer) on each devices.
  359. tied_params = find_tied_parameters(model)
  360. tied_params_map = {}
  361. for group in tied_params:
  362. for param_name in group:
  363. # data_ptr() is enough here, as `find_tied_parameters` finds tied params simply by comparing `param1 is param2`, so we don't need
  364. # to care about views of tensors through storage_offset.
  365. data_ptr = recursive_getattr(model, param_name).data_ptr()
  366. tied_params_map[data_ptr] = {}
  367. # Note: To handle the disk offloading case, we can not simply use weights_map[param_name].data_ptr() as the reference pointer,
  368. # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
  369. attach_align_device_hook_on_blocks(
  370. model,
  371. execution_device=execution_device,
  372. offload=offload,
  373. offload_buffers=offload_buffers,
  374. weights_map=weights_map,
  375. skip_keys=skip_keys,
  376. preload_module_classes=preload_module_classes,
  377. tied_params_map=tied_params_map,
  378. )
  379. # warn if there is any params on the meta device
  380. offloaded_devices_str = " and ".join(
  381. [device for device in set(device_map.values()) if device in ("cpu", "disk")]
  382. )
  383. if len(offloaded_devices_str) > 0:
  384. logger.warning(
  385. f"Some parameters are on the meta device because they were offloaded to the {offloaded_devices_str}."
  386. )
  387. # Attaching the hook may break tied weights, so we retie them
  388. retie_parameters(model, tied_params)
  389. # add warning to cuda and to method
  390. def add_warning(fn, model):
  391. @wraps(fn)
  392. def wrapper(*args, **kwargs):
  393. warning_msg = "You shouldn't move a model that is dispatched using accelerate hooks."
  394. if str(fn.__name__) == "to":
  395. to_device = torch._C._nn._parse_to(*args, **kwargs)[0]
  396. if to_device is not None:
  397. logger.warning(warning_msg)
  398. else:
  399. logger.warning(warning_msg)
  400. for param in model.parameters():
  401. if param.device == torch.device("meta"):
  402. raise RuntimeError("You can't move a model that has some modules offloaded to cpu or disk.")
  403. return fn(*args, **kwargs)
  404. return wrapper
  405. # Make sure to update _accelerate_added_attributes in hooks.py if you add any hook
  406. model.to = add_warning(model.to, model)
  407. if is_npu_available():
  408. model.npu = add_warning(model.npu, model)
  409. elif is_mlu_available():
  410. model.mlu = add_warning(model.mlu, model)
  411. elif is_sdaa_available():
  412. model.sdaa = add_warning(model.sdaa, model)
  413. elif is_musa_available():
  414. model.musa = add_warning(model.musa, model)
  415. elif is_xpu_available():
  416. model.xpu = add_warning(model.xpu, model)
  417. else:
  418. model.cuda = add_warning(model.cuda, model)
  419. # Check if we are using multi-gpus with RTX 4000 series
  420. use_multi_gpu = len([device for device in set(device_map.values()) if device not in ("cpu", "disk")]) > 1
  421. if use_multi_gpu and not check_cuda_p2p_ib_support():
  422. logger.warning(
  423. "We've detected an older driver with an RTX 4000 series GPU. These drivers have issues with P2P. "
  424. "This can affect the multi-gpu inference when using accelerate device_map."
  425. "Please make sure to update your driver to the latest version which resolves this."
  426. )
  427. else:
  428. device = list(device_map.values())[0]
  429. # `torch.Tensor.to(<int num>)` is not supported by `torch_npu` (see this [issue](https://github.com/Ascend/pytorch/issues/16)).
  430. if is_npu_available() and isinstance(device, int):
  431. device = f"npu:{device}"
  432. elif is_mlu_available() and isinstance(device, int):
  433. device = f"mlu:{device}"
  434. elif is_sdaa_available() and isinstance(device, int):
  435. device = f"sdaa:{device}"
  436. elif is_musa_available() and isinstance(device, int):
  437. device = f"musa:{device}"
  438. if device != "disk":
  439. model.to(device)
  440. else:
  441. raise ValueError(
  442. "You are trying to offload the whole model to the disk. Please use the `disk_offload` function instead."
  443. )
  444. # Convert OrderedDict back to dict for easier usage
  445. model.hf_device_map = dict(device_map)
  446. return model
  447. def load_checkpoint_and_dispatch(
  448. model: nn.Module,
  449. checkpoint: Union[str, os.PathLike],
  450. device_map: Optional[Union[str, dict[str, Union[int, str, torch.device]]]] = None,
  451. max_memory: Optional[dict[Union[int, str], Union[int, str]]] = None,
  452. no_split_module_classes: Optional[list[str]] = None,
  453. offload_folder: Optional[Union[str, os.PathLike]] = None,
  454. offload_buffers: bool = False,
  455. dtype: Optional[Union[str, torch.dtype]] = None,
  456. offload_state_dict: Optional[bool] = None,
  457. skip_keys: Optional[Union[str, list[str]]] = None,
  458. preload_module_classes: Optional[list[str]] = None,
  459. force_hooks: bool = False,
  460. strict: bool = False,
  461. full_state_dict: bool = True,
  462. broadcast_from_rank0: bool = False,
  463. ):
  464. """
  465. Loads a (potentially sharded) checkpoint inside a model, potentially sending weights to a given device as they are
  466. loaded and adds the various hooks that will make this model run properly (even if split across devices).
  467. Args:
  468. model (`torch.nn.Module`): The model in which we want to load a checkpoint.
  469. checkpoint (`str` or `os.PathLike`):
  470. The folder checkpoint to load. It can be:
  471. - a path to a file containing a whole model state dict
  472. - a path to a `.json` file containing the index to a sharded checkpoint
  473. - a path to a folder containing a unique `.index.json` file and the shards of a checkpoint.
  474. device_map (`Dict[str, Union[int, str, torch.device]]`, *optional*):
  475. A map that specifies where each submodule should go. It doesn't need to be refined to each parameter/buffer
  476. name, once a given module name is inside, every submodule of it will be sent to the same device.
  477. To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For more
  478. information about each option see [here](../concept_guides/big_model_inference#designing-a-device-map).
  479. Defaults to None, which means [`dispatch_model`] will not be called.
  480. max_memory (`Dict`, *optional*):
  481. A dictionary device identifier to maximum memory. Will default to the maximum memory available for each GPU
  482. and the available CPU RAM if unset.
  483. no_split_module_classes (`List[str]`, *optional*):
  484. A list of layer class names that should never be split across device (for instance any layer that has a
  485. residual connection).
  486. offload_folder (`str` or `os.PathLike`, *optional*):
  487. If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
  488. offload_buffers (`bool`, *optional*, defaults to `False`):
  489. In the layers that are offloaded on the CPU or the hard drive, whether or not to offload the buffers as
  490. well as the parameters.
  491. dtype (`str` or `torch.dtype`, *optional*):
  492. If provided, the weights will be converted to that type when loaded.
  493. offload_state_dict (`bool`, *optional*):
  494. If `True`, will temporarily offload the CPU state dict on the hard drive to avoid getting out of CPU RAM if
  495. the weight of the CPU state dict + the biggest shard does not fit. Will default to `True` if the device map
  496. picked contains `"disk"` values.
  497. skip_keys (`str` or `List[str]`, *optional*):
  498. A list of keys to ignore when moving inputs or outputs between devices.
  499. preload_module_classes (`List[str]`, *optional*):
  500. A list of classes whose instances should load all their weights (even in the submodules) at the beginning
  501. of the forward. This should only be used for classes that have submodules which are registered but not
  502. called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
  503. `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
  504. force_hooks (`bool`, *optional*, defaults to `False`):
  505. Whether or not to force device hooks to be attached to the model even if all layers are dispatched to a
  506. single device.
  507. strict (`bool`, *optional*, defaults to `False`):
  508. Whether to strictly enforce that the keys in the checkpoint state_dict match the keys of the model's
  509. state_dict.
  510. full_state_dict (`bool`, *optional*, defaults to `True`): if this is set to `True`, all the tensors in the
  511. loaded state_dict will be gathered. No ShardedTensor and DTensor will be in the loaded state_dict.
  512. broadcast_from_rank0 (`False`, *optional*, defaults to `False`): when the option is `True`, a distributed
  513. `ProcessGroup` must be initialized. rank0 should receive a full state_dict and will broadcast the tensors
  514. in the state_dict one by one to other ranks. Other ranks will receive the tensors and shard (if applicable)
  515. according to the local shards in the model.
  516. Example:
  517. ```python
  518. >>> from accelerate import init_empty_weights, load_checkpoint_and_dispatch
  519. >>> from huggingface_hub import hf_hub_download
  520. >>> from transformers import AutoConfig, AutoModelForCausalLM
  521. >>> # Download the Weights
  522. >>> checkpoint = "EleutherAI/gpt-j-6B"
  523. >>> weights_location = hf_hub_download(checkpoint, "pytorch_model.bin")
  524. >>> # Create a model and initialize it with empty weights
  525. >>> config = AutoConfig.from_pretrained(checkpoint)
  526. >>> with init_empty_weights():
  527. ... model = AutoModelForCausalLM.from_config(config)
  528. >>> # Load the checkpoint and dispatch it to the right devices
  529. >>> model = load_checkpoint_and_dispatch(
  530. ... model, weights_location, device_map="auto", no_split_module_classes=["GPTJBlock"]
  531. ... )
  532. ```
  533. """
  534. if isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
  535. raise ValueError(
  536. "If passing a string for `device_map`, please choose 'auto', 'balanced', 'balanced_low_0' or 'sequential'."
  537. )
  538. if isinstance(device_map, str):
  539. if device_map != "sequential":
  540. max_memory = get_balanced_memory(
  541. model,
  542. max_memory=max_memory,
  543. no_split_module_classes=no_split_module_classes,
  544. dtype=dtype,
  545. low_zero=(device_map == "balanced_low_0"),
  546. )
  547. device_map = infer_auto_device_map(
  548. model,
  549. max_memory=max_memory,
  550. no_split_module_classes=no_split_module_classes,
  551. dtype=dtype,
  552. offload_buffers=offload_buffers,
  553. )
  554. if offload_state_dict is None and device_map is not None and "disk" in device_map.values():
  555. offload_state_dict = True
  556. load_checkpoint_in_model(
  557. model,
  558. checkpoint,
  559. device_map=device_map,
  560. offload_folder=offload_folder,
  561. dtype=dtype,
  562. offload_state_dict=offload_state_dict,
  563. offload_buffers=offload_buffers,
  564. strict=strict,
  565. full_state_dict=full_state_dict,
  566. broadcast_from_rank0=broadcast_from_rank0,
  567. )
  568. if device_map is None:
  569. return model
  570. return dispatch_model(
  571. model,
  572. device_map=device_map,
  573. offload_dir=offload_folder,
  574. offload_buffers=offload_buffers,
  575. skip_keys=skip_keys,
  576. preload_module_classes=preload_module_classes,
  577. force_hooks=force_hooks,
  578. )
  579. def attach_layerwise_casting_hooks(
  580. module: torch.nn.Module,
  581. storage_dtype: torch.dtype,
  582. compute_dtype: torch.dtype,
  583. skip_modules_pattern: Optional[Union[str, tuple[str, ...]]] = None,
  584. skip_modules_classes: Optional[tuple[type[torch.nn.Module], ...]] = None,
  585. non_blocking: bool = False,
  586. ) -> None:
  587. r"""
  588. Applies layerwise casting to a given module. The module expected here is a PyTorch `nn.Module`. This is helpful for
  589. reducing memory requirements when one doesn't want to fully quantize a model. Model params can be kept in say,
  590. `torch.float8_e4m3fn` and upcasted to a higher precision like `torch.bfloat16` during forward pass and downcasted
  591. back to `torch.float8_e4m3fn` to realize memory savings.
  592. Args:
  593. module (`torch.nn.Module`):
  594. The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
  595. precision dtype for storage.
  596. storage_dtype (`torch.dtype`):
  597. The dtype to cast the module to before/after the forward pass for storage.
  598. compute_dtype (`torch.dtype`):
  599. The dtype to cast the module to during the forward pass for computation.
  600. skip_modules_pattern (`tuple[str, ...]`, defaults to `None`):
  601. A list of patterns to match the names of the modules to skip during the layerwise casting process. If set
  602. to `None` alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the
  603. module instead of its internal submodules.
  604. skip_modules_classes (`tuple[type[torch.nn.Module], ...]`, defaults to `None`):
  605. A list of module classes to skip during the layerwise casting process.
  606. non_blocking (`bool`, defaults to `False`):
  607. If `True`, the weight casting operations are non-blocking.
  608. Example:
  609. ```python
  610. >>> from accelerate.hooks import attach_layerwise_casting_hooks
  611. >>> from transformers import AutoModelForCausalLM
  612. >>> import torch
  613. >>> # Model
  614. >>> checkpoint = "EleutherAI/gpt-j-6B"
  615. >>> model = AutoModelForCausalLM.from_pretrained(checkpoint)
  616. >>> # Attach hooks and perform inference
  617. >>> attach_layerwise_casting_hooks(model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
  618. >>> with torch.no_grad():
  619. ... model(...)
  620. ```
  621. Users can also pass modules they want to avoid from getting downcasted.
  622. ```py
  623. >>> attach_layerwise_casting_hooks(
  624. ... model, storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16, skip_modules_pattern=["norm"]
  625. ... )
  626. ```
  627. """
  628. _attach_layerwise_casting_hooks(
  629. module, storage_dtype, compute_dtype, skip_modules_pattern, skip_modules_classes, non_blocking
  630. )
  631. def _attach_layerwise_casting_hooks(
  632. module: torch.nn.Module,
  633. storage_dtype: torch.dtype,
  634. compute_dtype: torch.dtype,
  635. skip_modules_pattern: Optional[Union[str, tuple[str, ...]]] = None,
  636. skip_modules_classes: Optional[tuple[type[torch.nn.Module], ...]] = None,
  637. non_blocking: bool = False,
  638. _prefix: str = "",
  639. ):
  640. should_skip = (skip_modules_classes is not None and isinstance(module, skip_modules_classes)) or (
  641. skip_modules_pattern is not None and any(re.search(pattern, _prefix) for pattern in skip_modules_pattern)
  642. )
  643. if should_skip:
  644. logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
  645. return
  646. if isinstance(module, SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING):
  647. logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
  648. add_hook_to_module(
  649. module,
  650. LayerwiseCastingHook(storage_dtype=storage_dtype, compute_dtype=compute_dtype, non_blocking=non_blocking),
  651. append=True,
  652. )
  653. return
  654. for name, submodule in module.named_children():
  655. layer_name = f"{_prefix}.{name}" if _prefix else name
  656. _attach_layerwise_casting_hooks(
  657. submodule,
  658. storage_dtype,
  659. compute_dtype,
  660. skip_modules_pattern,
  661. skip_modules_classes,
  662. non_blocking,
  663. _prefix=layer_name,
  664. )
  665. def _attach_context_parallel_hooks(
  666. model: nn.Module,
  667. ):
  668. """
  669. Monkeypatch huggingface's `transformers` model to fix attention mask issues when using context parallelism.
  670. This function attaches forward_pre_hooks to each self_attn module of the model, where each hook checks the
  671. args/kwargs, if they contain an attention mask, if it does, it will remove this mask, check if it is a causal mask,
  672. if yes, will add a kwarg `is_causal=True`, otherwise will raise an error. This is because context parallelism does
  673. not support attention masks. This function modifies the model in place.
  674. Args:
  675. model (`nn.Module`):
  676. The model to attach the hooks to.
  677. """
  678. def _self_attn_pre_forward_hook(_module, module_args, module_kwargs):
  679. if "attention_mask" in module_kwargs:
  680. module_kwargs["attention_mask"] = None
  681. module_kwargs["is_causal"] = True
  682. return module_args, module_kwargs
  683. for name, module in model.named_modules():
  684. # We hope (assume) that if user uses their own model (without this structure which transformers uses), they read the docs saying they can't pass in attention masks
  685. # Then these cases can happen:
  686. # 1) some modules end with a `self-attn` module, in which case we attach the hook, but the
  687. # there's no attention mask kwarg -> hook is a no-op
  688. # 2) some modules end with a `self-attn` module, in which case we attach the hook, and the
  689. # attention mask kwarg is passed -> hook will remove the attention mask and add
  690. # `is_causal=True` kwarg, which either crashes the training or fixes it
  691. # (training would crash anyway as attention mask isn't supported)
  692. # 3) no modules end with a `self-attn` module, in which case we don't attach the hook, this is
  693. # a no-op as well
  694. if name.endswith("self_attn"):
  695. # we want the hook to be executed first, to avoid any other hooks doing work on the attention mask
  696. module.register_forward_pre_hook(_self_attn_pre_forward_hook, with_kwargs=True, prepend=True)