hooks.py 33 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. from collections.abc import Mapping
  16. from typing import Optional, Union
  17. import torch
  18. import torch.nn as nn
  19. from .state import PartialState
  20. from .utils import (
  21. PrefixedDataset,
  22. find_device,
  23. named_module_tensors,
  24. send_to_device,
  25. set_module_tensor_to_device,
  26. )
  27. from .utils.imports import (
  28. is_mlu_available,
  29. is_musa_available,
  30. is_npu_available,
  31. )
  32. from .utils.memory import clear_device_cache
  33. from .utils.modeling import get_non_persistent_buffers
  34. from .utils.other import recursive_getattr
  35. _accelerate_added_attributes = ["to", "cuda", "npu", "xpu", "mlu", "sdaa", "musa"]
  36. class ModelHook:
  37. """
  38. A hook that contains callbacks to be executed just before and after the forward method of a model. The difference
  39. with PyTorch existing hooks is that they get passed along the kwargs.
  40. Class attribute:
  41. - **no_grad** (`bool`, *optional*, defaults to `False`) -- Whether or not to execute the actual forward pass under
  42. the `torch.no_grad()` context manager.
  43. """
  44. no_grad = False
  45. def init_hook(self, module):
  46. """
  47. To be executed when the hook is attached to the module.
  48. Args:
  49. module (`torch.nn.Module`): The module attached to this hook.
  50. """
  51. return module
  52. def pre_forward(self, module, *args, **kwargs):
  53. """
  54. To be executed just before the forward method of the model.
  55. Args:
  56. module (`torch.nn.Module`): The module whose forward pass will be executed just after this event.
  57. args (`Tuple[Any]`): The positional arguments passed to the module.
  58. kwargs (`Dict[Str, Any]`): The keyword arguments passed to the module.
  59. Returns:
  60. `Tuple[Tuple[Any], Dict[Str, Any]]`: A tuple with the treated `args` and `kwargs`.
  61. """
  62. return args, kwargs
  63. def post_forward(self, module, output):
  64. """
  65. To be executed just after the forward method of the model.
  66. Args:
  67. module (`torch.nn.Module`): The module whose forward pass been executed just before this event.
  68. output (`Any`): The output of the module.
  69. Returns:
  70. `Any`: The processed `output`.
  71. """
  72. return output
  73. def detach_hook(self, module):
  74. """
  75. To be executed when the hook is detached from a module.
  76. Args:
  77. module (`torch.nn.Module`): The module detached from this hook.
  78. """
  79. return module
  80. class SequentialHook(ModelHook):
  81. """
  82. A hook that can contain several hooks and iterates through them at each event.
  83. """
  84. def __init__(self, *hooks):
  85. self.hooks = hooks
  86. def init_hook(self, module):
  87. for hook in self.hooks:
  88. module = hook.init_hook(module)
  89. return module
  90. def pre_forward(self, module, *args, **kwargs):
  91. for hook in self.hooks:
  92. args, kwargs = hook.pre_forward(module, *args, **kwargs)
  93. return args, kwargs
  94. def post_forward(self, module, output):
  95. for hook in self.hooks:
  96. output = hook.post_forward(module, output)
  97. return output
  98. def detach_hook(self, module):
  99. for hook in self.hooks:
  100. module = hook.detach_hook(module)
  101. return module
  102. def add_hook_to_module(module: nn.Module, hook: ModelHook, append: bool = False):
  103. """
  104. Adds a hook to a given module. This will rewrite the `forward` method of the module to include the hook, to remove
  105. this behavior and restore the original `forward` method, use `remove_hook_from_module`.
  106. <Tip warning={true}>
  107. If the module already contains a hook, this will replace it with the new hook passed by default. To chain two hooks
  108. together, pass `append=True`, so it chains the current and new hook into an instance of the `SequentialHook` class.
  109. </Tip>
  110. Args:
  111. module (`torch.nn.Module`):
  112. The module to attach a hook to.
  113. hook (`ModelHook`):
  114. The hook to attach.
  115. append (`bool`, *optional*, defaults to `False`):
  116. Whether the hook should be chained with an existing one (if module already contains a hook) or not.
  117. Returns:
  118. `torch.nn.Module`: The same module, with the hook attached (the module is modified in place, so the result can
  119. be discarded).
  120. """
  121. if append and (getattr(module, "_hf_hook", None) is not None):
  122. old_hook = module._hf_hook
  123. remove_hook_from_module(module)
  124. hook = SequentialHook(old_hook, hook)
  125. if hasattr(module, "_hf_hook") and hasattr(module, "_old_forward"):
  126. # If we already put some hook on this module, we replace it with the new one.
  127. old_forward = module._old_forward
  128. else:
  129. old_forward = module.forward
  130. module._old_forward = old_forward
  131. module = hook.init_hook(module)
  132. module._hf_hook = hook
  133. def new_forward(module, *args, **kwargs):
  134. args, kwargs = module._hf_hook.pre_forward(module, *args, **kwargs)
  135. if module._hf_hook.no_grad:
  136. with torch.no_grad():
  137. output = module._old_forward(*args, **kwargs)
  138. else:
  139. output = module._old_forward(*args, **kwargs)
  140. return module._hf_hook.post_forward(module, output)
  141. # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
  142. # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
  143. if "GraphModuleImpl" in str(type(module)):
  144. module.__class__.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
  145. else:
  146. module.forward = functools.update_wrapper(functools.partial(new_forward, module), old_forward)
  147. return module
  148. def remove_hook_from_module(module: nn.Module, recurse=False):
  149. """
  150. Removes any hook attached to a module via `add_hook_to_module`.
  151. Args:
  152. module (`torch.nn.Module`): The module to attach a hook to.
  153. recurse (`bool`, **optional**): Whether to remove the hooks recursively
  154. Returns:
  155. `torch.nn.Module`: The same module, with the hook detached (the module is modified in place, so the result can
  156. be discarded).
  157. """
  158. if hasattr(module, "_hf_hook"):
  159. module._hf_hook.detach_hook(module)
  160. delattr(module, "_hf_hook")
  161. if hasattr(module, "_old_forward"):
  162. # Overriding a GraphModuleImpl forward freezes the forward call and later modifications on the graph will fail.
  163. # Reference: https://pytorch.slack.com/archives/C3PDTEV8E/p1705929610405409
  164. if "GraphModuleImpl" in str(type(module)):
  165. module.__class__.forward = module._old_forward
  166. else:
  167. module.forward = module._old_forward
  168. delattr(module, "_old_forward")
  169. # Remove accelerate added warning hooks from dispatch_model
  170. for attr in _accelerate_added_attributes:
  171. module.__dict__.pop(attr, None)
  172. if recurse:
  173. for child in module.children():
  174. remove_hook_from_module(child, recurse)
  175. return module
class AlignDevicesHook(ModelHook):
    """
    A generic `ModelHook` that ensures inputs and model weights are on the same device for the forward pass of the
    associated module, potentially offloading the weights after the forward pass.

    Args:
        execution_device (`torch.device`, *optional*):
            The device on which inputs and model weights should be placed before the forward pass.
        offload (`bool`, *optional*, defaults to `False`):
            Whether or not the weights should be offloaded after the forward pass.
        io_same_device (`bool`, *optional*, defaults to `False`):
            Whether or not the output should be placed on the same device as the input was.
        weights_map (`Mapping[str, torch.Tensor]`, *optional*):
            When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
        offload_buffers (`bool`, *optional*, defaults to `False`):
            Whether or not to include the associated module's buffers when offloading.
        place_submodules (`bool`, *optional*, defaults to `False`):
            Whether to place the submodules on `execution_device` during the `init_hook` event.
        skip_keys (`str` or `List[str]`, *optional*):
            Passed as `skip_keys` to `send_to_device` when moving the forward kwargs and the output; presumably the
            keys to leave where they are.
        tied_params_map (`Dict[int, Dict[torch.device, torch.Tensor]]`, *optional*):
            A map from data pointers to dictionaries of devices to already dispatched tied weights. It is shared with
            `set_module_tensor_to_device` to avoid duplicating memory for tied weights already loaded on the
            execution device. It is reset to `None` in `init_hook` when the execution device is `"meta"`.
    """

    def __init__(
        self,
        execution_device: Optional[Union[int, str, torch.device]] = None,
        offload: bool = False,
        io_same_device: bool = False,
        weights_map: Optional[Mapping] = None,
        offload_buffers: bool = False,
        place_submodules: bool = False,
        skip_keys: Optional[Union[str, list[str]]] = None,
        tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
    ):
        self.execution_device = execution_device
        self.offload = offload
        self.io_same_device = io_same_device
        self.weights_map = weights_map
        self.offload_buffers = offload_buffers
        self.place_submodules = place_submodules
        self.skip_keys = skip_keys

        # Will contain the input device when `io_same_device=True`.
        self.input_device = None
        # NOTE(review): the two dicts below are not read anywhere in this class; possibly used elsewhere - confirm
        # before removing.
        self.param_original_devices = {}
        self.buffer_original_devices = {}
        # Names of offloaded tensors detected as tied in `init_hook`; used in `pre_forward`.
        self.tied_params_names = set()

        # The hook pre_forward/post_forward need to have knowledge of this dictionary, as with offloading we want to avoid duplicating memory
        # for tied weights already loaded on the target execution device.
        self.tied_params_map = tied_params_map

    def __repr__(self):
        return (
            f"AlignDevicesHook(execution_device={self.execution_device}, offload={self.offload}, "
            f"io_same_device={self.io_same_device}, offload_buffers={self.offload_buffers}, "
            f"place_submodules={self.place_submodules}, skip_keys={repr(self.skip_keys)})"
        )

    def init_hook(self, module):
        # In case the AlignDevicesHook is on meta device, ignore tied weights as data_ptr() is then always zero.
        if self.execution_device == "meta" or self.execution_device == torch.device("meta"):
            self.tied_params_map = None

        if not self.offload and self.execution_device is not None:
            # No offloading: move the tensors (submodules' too when `place_submodules`) to the execution device once.
            for name, _ in named_module_tensors(module, recurse=self.place_submodules):
                set_module_tensor_to_device(module, name, self.execution_device, tied_params_map=self.tied_params_map)
        elif self.offload:
            # Record each tensor's current device so `detach_hook` can put it back there.
            self.original_devices = {
                name: param.device for name, param in named_module_tensors(module, recurse=self.place_submodules)
            }
            if self.weights_map is None:
                # No external weights map: keep CPU copies of the tensors to reload from in `pre_forward`.
                self.weights_map = {
                    name: param.to("cpu")
                    for name, param in named_module_tensors(
                        module, include_buffers=self.offload_buffers, recurse=self.place_submodules
                    )
                }

            for name, _ in named_module_tensors(
                module, include_buffers=self.offload_buffers, recurse=self.place_submodules, remove_non_persistent=True
            ):
                # When using disk offloading, we can not rely on `weights_map[name].data_ptr()` as the reference pointer,
                # as we have no guarantee that safetensors' `file.get_tensor()` will always give the same pointer.
                # As we have no reliable way to track the shared data pointer of tied weights in this case, we use tied_params_names: List[str]
                # to add on the fly pointers to `tied_params_map` in the pre_forward call.
                if (
                    self.tied_params_map is not None
                    and recursive_getattr(module, name).data_ptr() in self.tied_params_map
                ):
                    self.tied_params_names.add(name)

                # Drop the on-module storage; the values are loaded back from `weights_map` in `pre_forward`.
                set_module_tensor_to_device(module, name, "meta")

            if not self.offload_buffers and self.execution_device is not None:
                # Buffers are not offloaded: place them on the execution device permanently.
                for name, _ in module.named_buffers(recurse=self.place_submodules):
                    set_module_tensor_to_device(
                        module, name, self.execution_device, tied_params_map=self.tied_params_map
                    )
            elif self.offload_buffers and self.execution_device is not None:
                # The offload loop above used `remove_non_persistent=True`, which presumably excluded non-persistent
                # buffers; those are placed on the execution device here instead.
                for name in get_non_persistent_buffers(module, recurse=self.place_submodules):
                    set_module_tensor_to_device(
                        module, name, self.execution_device, tied_params_map=self.tied_params_map
                    )

        return module

    def pre_forward(self, module, *args, **kwargs):
        if self.io_same_device:
            # Remember where the inputs live so `post_forward` can send the output back there.
            self.input_device = find_device([args, kwargs])
        if self.offload:
            # (data_ptr, device) pairs added to `tied_params_map` during this call; removed again in `post_forward`.
            self.tied_pointers_to_remove = set()

            for name, _ in named_module_tensors(
                module,
                include_buffers=self.offload_buffers,
                recurse=self.place_submodules,
                remove_non_persistent=True,
            ):
                fp16_statistics = None
                value = self.weights_map[name]
                # For an int8 `weight` with a matching `SCB` entry in the weights map, pass that entry along as
                # `fp16_statistics` (presumably bitsandbytes int8 scaling statistics - confirm).
                if "weight" in name and name.replace("weight", "SCB") in self.weights_map.keys():
                    if value.dtype == torch.int8:
                        fp16_statistics = self.weights_map[name.replace("weight", "SCB")]

                # In case we are using offloading with tied weights, we need to keep track of the offloaded weights
                # that are loaded on device at this point, as we will need to remove them as well from the dictionary
                # self.tied_params_map in order to allow to free memory.
                if name in self.tied_params_names and value.data_ptr() not in self.tied_params_map:
                    self.tied_params_map[value.data_ptr()] = {}

                if (
                    value is not None
                    and self.tied_params_map is not None
                    and value.data_ptr() in self.tied_params_map
                    and self.execution_device not in self.tied_params_map[value.data_ptr()]
                ):
                    self.tied_pointers_to_remove.add((value.data_ptr(), self.execution_device))

                set_module_tensor_to_device(
                    module,
                    name,
                    self.execution_device,
                    value=value,
                    fp16_statistics=fp16_statistics,
                    tied_params_map=self.tied_params_map,
                )

        # Positional args are moved as a whole; `skip_keys` only applies to the kwargs.
        return send_to_device(args, self.execution_device), send_to_device(
            kwargs, self.execution_device, skip_keys=self.skip_keys
        )

    def post_forward(self, module, output):
        if self.offload:
            # Send the offloaded tensors back to "meta" so the execution-device copies can be freed.
            for name, _ in named_module_tensors(
                module,
                include_buffers=self.offload_buffers,
                recurse=self.place_submodules,
                remove_non_persistent=True,
            ):
                set_module_tensor_to_device(module, name, "meta")
                # Also clear the `state.SCB` / `state.CxB` attributes on `Linear8bitLt` modules.
                if type(module).__name__ == "Linear8bitLt":
                    module.state.SCB = None
                    module.state.CxB = None

            # We may have loaded tied weights into self.tied_params_map (avoiding to load them several times in e.g. submodules): remove them from
            # this dictionary to allow the garbage collector to do its job.
            for value_pointer, device in self.tied_pointers_to_remove:
                if isinstance(device, int):
                    # An int device index is rewritten as an explicit "npu:N" / "mlu:N" / "musa:N" string when that
                    # backend is available, before being looked up in `tied_params_map`.
                    if is_npu_available():
                        device = f"npu:{device}"
                    elif is_mlu_available():
                        device = f"mlu:{device}"
                    elif is_musa_available():
                        device = f"musa:{device}"
                if device in self.tied_params_map[value_pointer]:
                    del self.tied_params_map[value_pointer][device]
            self.tied_pointers_to_remove = set()
        if self.io_same_device and self.input_device is not None:
            output = send_to_device(output, self.input_device, skip_keys=self.skip_keys)

        return output

    def detach_hook(self, module):
        if self.offload:
            # Restore every tensor recorded in `init_hook` that was not on "meta" to its original device, using the
            # value from `weights_map` when present.
            for name, device in self.original_devices.items():
                if device != torch.device("meta"):
                    set_module_tensor_to_device(module, name, device, value=self.weights_map.get(name, None))
        return module
  341. def attach_execution_device_hook(
  342. module: torch.nn.Module,
  343. execution_device: Union[int, str, torch.device],
  344. skip_keys: Optional[Union[str, list[str]]] = None,
  345. preload_module_classes: Optional[list[str]] = None,
  346. tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
  347. ):
  348. """
  349. Recursively attaches `AlignDevicesHook` to all submodules of a given model to make sure they have the right
  350. execution device
  351. Args:
  352. module (`torch.nn.Module`):
  353. The module where we want to attach the hooks.
  354. execution_device (`int`, `str` or `torch.device`):
  355. The device on which inputs and model weights should be placed before the forward pass.
  356. skip_keys (`str` or `List[str]`, *optional*):
  357. A list of keys to ignore when moving inputs or outputs between devices.
  358. preload_module_classes (`List[str]`, *optional*):
  359. A list of classes whose instances should load all their weights (even in the submodules) at the beginning
  360. of the forward. This should only be used for classes that have submodules which are registered but not
  361. called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
  362. `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
  363. tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
  364. A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
  365. device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
  366. instead of duplicating memory.
  367. """
  368. if not hasattr(module, "_hf_hook") and len(module.state_dict()) > 0:
  369. add_hook_to_module(
  370. module,
  371. AlignDevicesHook(execution_device, skip_keys=skip_keys, tied_params_map=tied_params_map),
  372. )
  373. # Break the recursion if we get to a preload module.
  374. if preload_module_classes is not None and module.__class__.__name__ in preload_module_classes:
  375. return
  376. for child in module.children():
  377. attach_execution_device_hook(
  378. child,
  379. execution_device,
  380. skip_keys=skip_keys,
  381. preload_module_classes=preload_module_classes,
  382. tied_params_map=tied_params_map,
  383. )
  384. def attach_align_device_hook(
  385. module: torch.nn.Module,
  386. execution_device: Optional[torch.device] = None,
  387. offload: bool = False,
  388. weights_map: Optional[Mapping] = None,
  389. offload_buffers: bool = False,
  390. module_name: str = "",
  391. skip_keys: Optional[Union[str, list[str]]] = None,
  392. preload_module_classes: Optional[list[str]] = None,
  393. tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
  394. ):
  395. """
  396. Recursively attaches `AlignDevicesHook` to all submodules of a given model that have direct parameters and/or
  397. buffers.
  398. Args:
  399. module (`torch.nn.Module`):
  400. The module where we want to attach the hooks.
  401. execution_device (`torch.device`, *optional*):
  402. The device on which inputs and model weights should be placed before the forward pass.
  403. offload (`bool`, *optional*, defaults to `False`):
  404. Whether or not the weights should be offloaded after the forward pass.
  405. weights_map (`Mapping[str, torch.Tensor]`, *optional*):
  406. When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
  407. offload_buffers (`bool`, *optional*, defaults to `False`):
  408. Whether or not to include the associated module's buffers when offloading.
  409. module_name (`str`, *optional*, defaults to `""`):
  410. The name of the module.
  411. skip_keys (`str` or `List[str]`, *optional*):
  412. A list of keys to ignore when moving inputs or outputs between devices.
  413. preload_module_classes (`List[str]`, *optional*):
  414. A list of classes whose instances should load all their weights (even in the submodules) at the beginning
  415. of the forward. This should only be used for classes that have submodules which are registered but not
  416. called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
  417. `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
  418. tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
  419. A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
  420. device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
  421. instead of duplicating memory.
  422. """
  423. # Attach the hook on this module if it has any direct tensor.
  424. directs = named_module_tensors(module)
  425. full_offload = (
  426. offload and preload_module_classes is not None and module.__class__.__name__ in preload_module_classes
  427. )
  428. if len(list(directs)) > 0 or full_offload:
  429. if weights_map is not None:
  430. prefix = f"{module_name}." if len(module_name) > 0 else ""
  431. prefixed_weights_map = PrefixedDataset(weights_map, prefix)
  432. else:
  433. prefixed_weights_map = None
  434. hook = AlignDevicesHook(
  435. execution_device=execution_device,
  436. offload=offload,
  437. weights_map=prefixed_weights_map,
  438. offload_buffers=offload_buffers,
  439. place_submodules=full_offload,
  440. skip_keys=skip_keys,
  441. tied_params_map=tied_params_map,
  442. )
  443. add_hook_to_module(module, hook, append=True)
  444. # We stop the recursion in case we hit the full offload.
  445. if full_offload:
  446. return
  447. # Recurse on all children of the module.
  448. for child_name, child in module.named_children():
  449. child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
  450. attach_align_device_hook(
  451. child,
  452. execution_device=execution_device,
  453. offload=offload,
  454. weights_map=weights_map,
  455. offload_buffers=offload_buffers,
  456. module_name=child_name,
  457. preload_module_classes=preload_module_classes,
  458. skip_keys=skip_keys,
  459. tied_params_map=tied_params_map,
  460. )
  461. def remove_hook_from_submodules(module: nn.Module):
  462. """
  463. Recursively removes all hooks attached on the submodules of a given model.
  464. Args:
  465. module (`torch.nn.Module`): The module on which to remove all hooks.
  466. """
  467. remove_hook_from_module(module)
  468. for child in module.children():
  469. remove_hook_from_submodules(child)
  470. def attach_align_device_hook_on_blocks(
  471. module: nn.Module,
  472. execution_device: Optional[Union[torch.device, dict[str, torch.device]]] = None,
  473. offload: Union[bool, dict[str, bool]] = False,
  474. weights_map: Optional[Mapping] = None,
  475. offload_buffers: bool = False,
  476. module_name: str = "",
  477. skip_keys: Optional[Union[str, list[str]]] = None,
  478. preload_module_classes: Optional[list[str]] = None,
  479. tied_params_map: Optional[dict[int, dict[torch.device, torch.Tensor]]] = None,
  480. ):
  481. """
  482. Attaches `AlignDevicesHook` to all blocks of a given model as needed.
  483. Args:
  484. module (`torch.nn.Module`):
  485. The module where we want to attach the hooks.
  486. execution_device (`torch.device` or `Dict[str, torch.device]`, *optional*):
  487. The device on which inputs and model weights should be placed before the forward pass. It can be one device
  488. for the whole module, or a dictionary mapping module name to device.
  489. offload (`bool`, *optional*, defaults to `False`):
  490. Whether or not the weights should be offloaded after the forward pass. It can be one boolean for the whole
  491. module, or a dictionary mapping module name to boolean.
  492. weights_map (`Mapping[str, torch.Tensor]`, *optional*):
  493. When the model weights are offloaded, a (potentially lazy) map from param names to the tensor values.
  494. offload_buffers (`bool`, *optional*, defaults to `False`):
  495. Whether or not to include the associated module's buffers when offloading.
  496. module_name (`str`, *optional*, defaults to `""`):
  497. The name of the module.
  498. skip_keys (`str` or `List[str]`, *optional*):
  499. A list of keys to ignore when moving inputs or outputs between devices.
  500. preload_module_classes (`List[str]`, *optional*):
  501. A list of classes whose instances should load all their weights (even in the submodules) at the beginning
  502. of the forward. This should only be used for classes that have submodules which are registered but not
  503. called directly during the forward, for instance if a `dense` linear layer is registered, but at forward,
  504. `dense.weight` and `dense.bias` are used in some operations instead of calling `dense` directly.
  505. tied_params_map (Optional[Dict[int, Dict[torch.device, torch.Tensor]]], *optional*, defaults to `None`):
  506. A map of data pointers to dictionaries of devices to already dispatched tied weights. For a given execution
  507. device, this parameter is useful to reuse the first available pointer of a shared weight for all others,
  508. instead of duplicating memory.
  509. """
  510. # If one device and one offload, we've got one hook.
  511. if not isinstance(execution_device, Mapping) and not isinstance(offload, dict):
  512. if not offload:
  513. hook = AlignDevicesHook(
  514. execution_device=execution_device,
  515. io_same_device=True,
  516. skip_keys=skip_keys,
  517. place_submodules=True,
  518. tied_params_map=tied_params_map,
  519. )
  520. add_hook_to_module(module, hook)
  521. else:
  522. attach_align_device_hook(
  523. module,
  524. execution_device=execution_device,
  525. offload=True,
  526. weights_map=weights_map,
  527. offload_buffers=offload_buffers,
  528. module_name=module_name,
  529. skip_keys=skip_keys,
  530. tied_params_map=tied_params_map,
  531. )
  532. return
  533. if not isinstance(execution_device, Mapping):
  534. execution_device = {key: execution_device for key in offload.keys()}
  535. if not isinstance(offload, Mapping):
  536. offload = {key: offload for key in execution_device.keys()}
  537. if module_name in execution_device and module_name in offload and not offload[module_name]:
  538. hook = AlignDevicesHook(
  539. execution_device=execution_device[module_name],
  540. offload_buffers=offload_buffers,
  541. io_same_device=(module_name == ""),
  542. place_submodules=True,
  543. skip_keys=skip_keys,
  544. tied_params_map=tied_params_map,
  545. )
  546. add_hook_to_module(module, hook)
  547. attach_execution_device_hook(
  548. module, execution_device[module_name], skip_keys=skip_keys, tied_params_map=tied_params_map
  549. )
  550. elif module_name in execution_device and module_name in offload:
  551. attach_align_device_hook(
  552. module,
  553. execution_device=execution_device[module_name],
  554. offload=True,
  555. weights_map=weights_map,
  556. offload_buffers=offload_buffers,
  557. module_name=module_name,
  558. skip_keys=skip_keys,
  559. preload_module_classes=preload_module_classes,
  560. tied_params_map=tied_params_map,
  561. )
  562. if not hasattr(module, "_hf_hook"):
  563. hook = AlignDevicesHook(
  564. execution_device=execution_device[module_name],
  565. io_same_device=(module_name == ""),
  566. skip_keys=skip_keys,
  567. tied_params_map=tied_params_map,
  568. )
  569. add_hook_to_module(module, hook)
  570. attach_execution_device_hook(
  571. module,
  572. execution_device[module_name],
  573. preload_module_classes=preload_module_classes,
  574. skip_keys=skip_keys,
  575. tied_params_map=tied_params_map,
  576. )
  577. elif module_name == "":
  578. hook = AlignDevicesHook(
  579. execution_device=execution_device.get(""),
  580. io_same_device=True,
  581. skip_keys=skip_keys,
  582. tied_params_map=tied_params_map,
  583. )
  584. add_hook_to_module(module, hook)
  585. for child_name, child in module.named_children():
  586. child_name = f"{module_name}.{child_name}" if len(module_name) > 0 else child_name
  587. attach_align_device_hook_on_blocks(
  588. child,
  589. execution_device=execution_device,
  590. offload=offload,
  591. weights_map=weights_map,
  592. offload_buffers=offload_buffers,
  593. module_name=child_name,
  594. preload_module_classes=preload_module_classes,
  595. skip_keys=skip_keys,
  596. tied_params_map=tied_params_map,
  597. )
  598. class CpuOffload(ModelHook):
  599. """
  600. Offloads a model on the CPU until its forward pass is called. The model will not be offloaded back to the CPU after
  601. the forward, the user needs to call the `init_hook` method again for this.
  602. Args:
  603. execution_device(`str`, `int` or `torch.device`, *optional*):
  604. The device on which the model should be executed. Will default to the MPS device if it's available, then
  605. GPU 0 if there is a GPU, and finally to the CPU.
  606. prev_module_hook (`UserCpuOffloadHook`, *optional*):
  607. The hook sent back by [`cpu_offload_with_hook`] for a previous model in the pipeline you are running. If
  608. passed, its offload method will be called just before the forward of the model to which this hook is
  609. attached.
  610. """
  611. def __init__(
  612. self,
  613. execution_device: Optional[Union[str, int, torch.device]] = None,
  614. prev_module_hook: Optional["UserCpuOffloadHook"] = None,
  615. ):
  616. self.prev_module_hook = prev_module_hook
  617. self.execution_device = execution_device if execution_device is not None else PartialState().default_device
  618. def init_hook(self, module):
  619. return module.to("cpu")
  620. def pre_forward(self, module, *args, **kwargs):
  621. if self.prev_module_hook is not None and isinstance(self.prev_module_hook, UserCpuOffloadHook):
  622. prev_module = self.prev_module_hook.model
  623. prev_device = next(prev_module.parameters()).device
  624. # Only offload the previous module if it is not already on CPU.
  625. if prev_device != torch.device("cpu"):
  626. self.prev_module_hook.offload()
  627. clear_device_cache()
  628. # If the current device is already the self.execution_device, we can skip the transfer.
  629. current_device = next(module.parameters()).device
  630. if current_device == self.execution_device:
  631. return args, kwargs
  632. module.to(self.execution_device)
  633. return send_to_device(args, self.execution_device), send_to_device(kwargs, self.execution_device)
  634. class UserCpuOffloadHook:
  635. """
  636. A simple hook grouping a model and a `ModelHook`, which provides easy APIs for to call the init method of the hook
  637. or remove it entirely.
  638. """
  639. def __init__(self, model, hook):
  640. self.model = model
  641. self.hook = hook
  642. def offload(self):
  643. self.hook.init_hook(self.model)
  644. def remove(self):
  645. remove_hook_from_module(self.model)
  646. class LayerwiseCastingHook(ModelHook):
  647. r"""
  648. A hook that casts the weights of a module to a high precision dtype for computation, and to a low precision dtype
  649. for storage. This process may lead to quality loss in the output, but can significantly reduce the memory
  650. footprint.
  651. """
  652. _is_stateful = False
  653. def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
  654. self.storage_dtype = storage_dtype
  655. self.compute_dtype = compute_dtype
  656. self.non_blocking = non_blocking
  657. def init_hook(self, module: torch.nn.Module):
  658. module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
  659. return module
  660. def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
  661. module.to(dtype=self.compute_dtype, non_blocking=self.non_blocking)
  662. return args, kwargs
  663. def post_forward(self, module: torch.nn.Module, output):
  664. module.to(dtype=self.storage_dtype, non_blocking=self.non_blocking)
  665. return output