other.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. import platform
  16. import re
  17. import socket
  18. from codecs import encode
  19. from collections import OrderedDict
  20. from functools import partial, reduce
  21. from types import MethodType
  22. from typing import Optional
  23. import numpy as np
  24. import torch
  25. from packaging.version import Version
  26. from safetensors.torch import save_file as safe_save_file
  27. from ..commands.config.default import write_basic_config # noqa: F401
  28. from ..logging import get_logger
  29. from ..state import PartialState
  30. from .constants import FSDP_PYTORCH_VERSION
  31. from .dataclasses import DistributedType
  32. from .imports import (
  33. is_deepspeed_available,
  34. is_numpy_available,
  35. is_torch_distributed_available,
  36. is_torch_xla_available,
  37. is_weights_only_available,
  38. )
  39. from .modeling import id_tensor_storage
  40. from .transformer_engine import convert_model
  41. from .versions import is_torch_version
logger = get_logger(__name__)

# torch_xla is optional: `xm` is only bound when an XLA runtime is available. It is used by `save()` below to move
# XLA tensors to the CPU before serialization.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm
  45. def is_compiled_module(module: torch.nn.Module) -> bool:
  46. """
  47. Check whether the module was compiled with torch.compile()
  48. """
  49. if not hasattr(torch, "_dynamo"):
  50. return False
  51. return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
  52. def has_compiled_regions(module: torch.nn.Module) -> bool:
  53. """
  54. Check whether the module has submodules that were compiled with `torch.compile()`.
  55. """
  56. if not hasattr(torch, "_dynamo"):
  57. return False
  58. if module._modules:
  59. for submodule in module.modules():
  60. if isinstance(submodule, torch._dynamo.eval_frame.OptimizedModule):
  61. return True
  62. return False
  63. def is_repeated_blocks(module: torch.nn.Module) -> bool:
  64. """
  65. Check whether the module is a repeated block, i.e. `torch.nn.ModuleList` with all children of the same class. This
  66. is useful to determine whether we should apply regional compilation to the module.
  67. """
  68. return isinstance(module, torch.nn.ModuleList) and all(isinstance(m, module[0].__class__) for m in module)
  69. def has_repeated_blocks(module: torch.nn.Module) -> bool:
  70. """
  71. Check whether the module has repeated blocks, i.e. `torch.nn.ModuleList` with all children of the same class, at
  72. any level of the module hierarchy. This is useful to determine whether we should apply regional compilation to the
  73. module.
  74. """
  75. if module._modules:
  76. for submodule in module.modules():
  77. if is_repeated_blocks(submodule):
  78. return True
  79. return False
  80. def compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
  81. """
  82. Performs regional compilation where we target repeated blocks of the same class and compile them sequentially to
  83. hit the compiler's cache. For example, in `GPT2LMHeadModel`, the repeated block/class is `GPT2Block`, and can be
  84. accessed as `model.transformer.h[0]`. The rest of the model (e.g. model.lm_head) is compiled separately.
  85. This allows us to speed up the compilation overhead / cold start of models like LLMs and Transformers in general.
  86. See https://pytorch.org/tutorials/recipes/regional_compilation.html for more details.
  87. Args:
  88. module (`torch.nn.Module`):
  89. The model to compile.
  90. **compile_kwargs:
  91. Additional keyword arguments to pass to `torch.compile()`.
  92. Returns:
  93. `torch.nn.Module`: A new instance of the model with some compiled regions.
  94. Example:
  95. ```python
  96. >>> from accelerate.utils import compile_regions
  97. >>> from transformers import AutoModelForCausalLM
  98. >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
  99. >>> compiled_model = compile_regions(model, mode="reduce-overhead")
  100. >>> compiled_model.transformer.h[0]
  101. OptimizedModule(
  102. (_orig_mod): GPT2Block(
  103. (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  104. (attn): GPT2Attention(
  105. (c_attn): Conv1D(nf=2304, nx=768)
  106. (c_proj): Conv1D(nf=768, nx=768)
  107. (attn_dropout): Dropout(p=0.1, inplace=False)
  108. (resid_dropout): Dropout(p=0.1, inplace=False)
  109. )
  110. (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  111. (mlp): GPT2MLP(
  112. (c_fc): Conv1D(nf=3072, nx=768)
  113. (c_proj): Conv1D(nf=768, nx=3072)
  114. (act): NewGELUActivation()
  115. (dropout): Dropout(p=0.1, inplace=False)
  116. )
  117. )
  118. )
  119. ```
  120. """
  121. def _compile_regions(module: torch.nn.Module, **compile_kwargs) -> torch.nn.Module:
  122. if is_repeated_blocks(module):
  123. new_module = torch.nn.ModuleList()
  124. for submodule in module:
  125. new_module.append(torch.compile(submodule, **compile_kwargs))
  126. elif has_repeated_blocks(module):
  127. new_module = module.__class__.__new__(module.__class__)
  128. new_module.__dict__.update(module.__dict__)
  129. new_module._modules = {}
  130. for name, submodule in module.named_children():
  131. new_module.add_module(name, _compile_regions(submodule, **compile_kwargs))
  132. else:
  133. new_module = torch.compile(module, **compile_kwargs)
  134. return new_module
  135. new_module = _compile_regions(module, **compile_kwargs)
  136. if "_orig_mod" not in new_module.__dict__:
  137. # Keeps a reference to the original module to decompile/unwrap it later
  138. new_module.__dict__["_orig_mod"] = module
  139. return new_module
  140. def compile_regions_deepspeed(module: torch.nn.Module, **compile_kwargs):
  141. """
  142. Performs regional compilation the same way as `compile_regions`, but specifically for `DeepSpeedEngine.module`.
  143. Since the model is wrapped in a `DeepSpeedEngine` and has many added hooks, offloaded parameters, etc that
  144. `torch.compile(...)` interferes with, version of trgional compilation uses the inplace `module.compile()` method
  145. instead.
  146. Args:
  147. module (`torch.nn.Module`):
  148. The model to compile.
  149. **compile_kwargs:
  150. Additional keyword arguments to pass to `module.compile()`.
  151. """
  152. if is_repeated_blocks(module):
  153. for submodule in module:
  154. submodule.compile(**compile_kwargs)
  155. elif has_repeated_blocks(module):
  156. for child in module.children():
  157. compile_regions_deepspeed(child, **compile_kwargs)
  158. else: # leaf node
  159. module.compile(**compile_kwargs)
  160. def model_has_dtensor(model: torch.nn.Module) -> bool:
  161. """
  162. Check if the model has DTensor parameters.
  163. Args:
  164. model (`torch.nn.Module`):
  165. The model to check.
  166. Returns:
  167. `bool`: Whether the model has DTensor parameters.
  168. """
  169. if is_torch_version(">=", "2.5.0"):
  170. from torch.distributed.tensor import DTensor
  171. else:
  172. # from torch 2.0.0 (oldest supported accelerate torch version), DTensor is in torch.distributed._tensor
  173. from torch.distributed._tensor import DTensor
  174. return any(isinstance(p, DTensor) for p in model.parameters())
  175. def extract_model_from_parallel(
  176. model, keep_fp32_wrapper: bool = True, keep_torch_compile: bool = True, recursive: bool = False
  177. ):
  178. """
  179. Extract a model from its distributed containers.
  180. Args:
  181. model (`torch.nn.Module`):
  182. The model to extract.
  183. keep_fp32_wrapper (`bool`, *optional*):
  184. Whether to remove mixed precision hooks from the model.
  185. keep_torch_compile (`bool`, *optional*):
  186. Whether to unwrap compiled model.
  187. recursive (`bool`, *optional*, defaults to `False`):
  188. Whether to recursively extract all cases of `module.module` from `model` as well as unwrap child sublayers
  189. recursively, not just the top-level distributed containers.
  190. Returns:
  191. `torch.nn.Module`: The extracted model.
  192. """
  193. options = (torch.nn.parallel.DistributedDataParallel, torch.nn.DataParallel)
  194. is_compiled = is_compiled_module(model)
  195. has_compiled = has_compiled_regions(model)
  196. if is_compiled:
  197. compiled_model = model
  198. model = model._orig_mod
  199. elif has_compiled:
  200. compiled_model = model
  201. model = model.__dict__["_orig_mod"]
  202. if is_deepspeed_available():
  203. from deepspeed import DeepSpeedEngine
  204. options += (DeepSpeedEngine,)
  205. if is_torch_version(">=", FSDP_PYTORCH_VERSION) and is_torch_distributed_available():
  206. from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
  207. options += (FSDP,)
  208. while isinstance(model, options):
  209. model = model.module
  210. if recursive:
  211. # This is needed in cases such as using FSDPv2 on XLA
  212. def _recursive_unwrap(module):
  213. # Wrapped modules are standardly wrapped as `module`, similar to the cases earlier
  214. # with DDP, DataParallel, DeepSpeed, and FSDP
  215. if hasattr(module, "module"):
  216. unwrapped_module = _recursive_unwrap(module.module)
  217. else:
  218. unwrapped_module = module
  219. # Next unwrap child sublayers recursively
  220. for name, child in unwrapped_module.named_children():
  221. setattr(unwrapped_module, name, _recursive_unwrap(child))
  222. return unwrapped_module
  223. # Start with top-level
  224. model = _recursive_unwrap(model)
  225. if not keep_fp32_wrapper:
  226. forward = model.forward
  227. original_forward = model.__dict__.pop("_original_forward", None)
  228. if original_forward is not None:
  229. while hasattr(forward, "__wrapped__"):
  230. forward = forward.__wrapped__
  231. if forward == original_forward:
  232. break
  233. model.forward = MethodType(forward, model)
  234. if getattr(model, "_converted_to_transformer_engine", False):
  235. convert_model(model, to_transformer_engine=False)
  236. if keep_torch_compile:
  237. if is_compiled:
  238. compiled_model._orig_mod = model
  239. model = compiled_model
  240. elif has_compiled:
  241. compiled_model.__dict__["_orig_mod"] = model
  242. model = compiled_model
  243. return model
  244. def wait_for_everyone():
  245. """
  246. Introduces a blocking point in the script, making sure all processes have reached this point before continuing.
  247. <Tip warning={true}>
  248. Make sure all processes will reach this instruction otherwise one of your processes will hang forever.
  249. </Tip>
  250. """
  251. PartialState().wait_for_everyone()
  252. def clean_state_dict_for_safetensors(state_dict: dict):
  253. """
  254. Cleans the state dictionary from a model and removes tensor aliasing if present.
  255. Args:
  256. state_dict (`dict`):
  257. The state dictionary from a model
  258. """
  259. ptrs = collections.defaultdict(list)
  260. # When bnb serialization is used, weights in state dict can be strings
  261. for name, tensor in state_dict.items():
  262. if not isinstance(tensor, str):
  263. ptrs[id_tensor_storage(tensor)].append(name)
  264. # These are all pointers of tensors with shared memory
  265. shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
  266. warn_names = set()
  267. for names in shared_ptrs.values():
  268. # When not all duplicates have been cleaned, we still remove those keys but put a clear warning.
  269. # If the link between tensors was done at runtime then `from_pretrained` will not get
  270. # the key back leading to random tensor. A proper warning will be shown
  271. # during reload (if applicable), but since the file is not necessarily compatible with
  272. # the config, better show a proper warning.
  273. found_names = [name for name in names if name in state_dict]
  274. warn_names.update(found_names[1:])
  275. for name in found_names[1:]:
  276. del state_dict[name]
  277. if len(warn_names) > 0:
  278. logger.warning(
  279. f"Removed shared tensor {warn_names} while saving. This should be OK, but check by verifying that you don't receive any warning while reloading",
  280. )
  281. state_dict = {k: v.contiguous() if isinstance(v, torch.Tensor) else v for k, v in state_dict.items()}
  282. return state_dict
  283. def save(obj, f, save_on_each_node: bool = False, safe_serialization: bool = False):
  284. """
  285. Save the data to disk. Use in place of `torch.save()`.
  286. Args:
  287. obj:
  288. The data to save
  289. f:
  290. The file (or file-like object) to use to save the data
  291. save_on_each_node (`bool`, *optional*, defaults to `False`):
  292. Whether to only save on the global main process
  293. safe_serialization (`bool`, *optional*, defaults to `False`):
  294. Whether to save `obj` using `safetensors` or the traditional PyTorch way (that uses `pickle`).
  295. """
  296. # When TorchXLA is enabled, it's necessary to transfer all data to the CPU before saving.
  297. # Another issue arises with `id_tensor_storage`, which treats all XLA tensors as identical.
  298. # If tensors remain on XLA, calling `clean_state_dict_for_safetensors` will result in only
  299. # one XLA tensor remaining.
  300. if PartialState().distributed_type == DistributedType.XLA:
  301. obj = xm._maybe_convert_to_cpu(obj)
  302. # Check if it's a model and remove duplicates
  303. if safe_serialization:
  304. save_func = partial(safe_save_file, metadata={"format": "pt"})
  305. if isinstance(obj, OrderedDict):
  306. obj = clean_state_dict_for_safetensors(obj)
  307. else:
  308. save_func = torch.save
  309. if PartialState().is_main_process and not save_on_each_node:
  310. save_func(obj, f)
  311. elif PartialState().is_local_main_process and save_on_each_node:
  312. save_func(obj, f)
# The following are considered "safe" globals to reconstruct various types of objects when using `weights_only=True`
# These should be added and then removed after loading in the file
# Use `np._core` when `is_numpy_available("2.0.0")` holds (presumably NumPy >= 2.0), otherwise the older `np.core`.
np_core = np._core if is_numpy_available("2.0.0") else np.core
TORCH_SAFE_GLOBALS = [
    # numpy arrays are just numbers, not objects, so we can reconstruct them safely
    np_core.multiarray._reconstruct,
    np.ndarray,
    # The following are needed for the RNG states
    encode,
    np.dtype,
]
# `np.dtypes.UInt32DType` is only referenced behind this NumPy version guard.
if is_numpy_available("1.25.0"):
    TORCH_SAFE_GLOBALS.append(np.dtypes.UInt32DType)
  326. def load(f, map_location=None, **kwargs):
  327. """
  328. Compatible drop-in replacement of `torch.load()` which allows for `weights_only` to be used if `torch` version is
  329. 2.4.0 or higher. Otherwise will ignore the kwarg.
  330. Will also add (and then remove) an exception for numpy arrays
  331. Args:
  332. f:
  333. The file (or file-like object) to use to load the data
  334. map_location:
  335. a function, `torch.device`, string or a dict specifying how to remap storage locations
  336. **kwargs:
  337. Additional keyword arguments to pass to `torch.load()`.
  338. """
  339. try:
  340. if is_weights_only_available():
  341. old_safe_globals = torch.serialization.get_safe_globals()
  342. if "weights_only" not in kwargs:
  343. kwargs["weights_only"] = True
  344. torch.serialization.add_safe_globals(TORCH_SAFE_GLOBALS)
  345. else:
  346. kwargs.pop("weights_only", None)
  347. loaded_obj = torch.load(f, map_location=map_location, **kwargs)
  348. finally:
  349. if is_weights_only_available():
  350. torch.serialization.clear_safe_globals()
  351. if old_safe_globals:
  352. torch.serialization.add_safe_globals(old_safe_globals)
  353. return loaded_obj
  354. def get_pretty_name(obj):
  355. """
  356. Gets a pretty name from `obj`.
  357. """
  358. if not hasattr(obj, "__qualname__") and not hasattr(obj, "__name__"):
  359. obj = getattr(obj, "__class__", obj)
  360. if hasattr(obj, "__qualname__"):
  361. return obj.__qualname__
  362. if hasattr(obj, "__name__"):
  363. return obj.__name__
  364. return str(obj)
  365. def merge_dicts(source, destination):
  366. """
  367. Recursively merges two dictionaries.
  368. Args:
  369. source (`dict`): The dictionary to merge into `destination`.
  370. destination (`dict`): The dictionary to merge `source` into.
  371. """
  372. for key, value in source.items():
  373. if isinstance(value, dict):
  374. node = destination.setdefault(key, {})
  375. merge_dicts(value, node)
  376. else:
  377. destination[key] = value
  378. return destination
  379. def is_port_in_use(port: Optional[int] = None) -> bool:
  380. """
  381. Checks if a port is in use on `localhost`. Useful for checking if multiple `accelerate launch` commands have been
  382. run and need to see if the port is already in use.
  383. """
  384. if port is None:
  385. port = 29500
  386. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  387. return s.connect_ex(("localhost", port)) == 0
  388. def get_free_port() -> int:
  389. """
  390. Gets a free port on `localhost`. Useful for automatic port selection when port 0 is specified in distributed
  391. training scenarios.
  392. Returns:
  393. int: An available port number
  394. """
  395. with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
  396. s.bind(("", 0)) # bind to port 0 for OS to assign a free port
  397. return s.getsockname()[1]
  398. def convert_bytes(size):
  399. "Converts `size` from bytes to the largest possible unit"
  400. for x in ["bytes", "KB", "MB", "GB", "TB"]:
  401. if size < 1024.0:
  402. return f"{round(size, 2)} {x}"
  403. size /= 1024.0
  404. return f"{round(size, 2)} PB"
  405. def check_os_kernel():
  406. """Warns if the kernel version is below the recommended minimum on Linux."""
  407. # see issue #1929
  408. info = platform.uname()
  409. system = info.system
  410. if system != "Linux":
  411. return
  412. _, version, *_ = re.split(r"(\d+\.\d+\.\d+)", info.release)
  413. min_version = "5.5.0"
  414. if Version(version) < Version(min_version):
  415. msg = (
  416. f"Detected kernel version {version}, which is below the recommended minimum of {min_version}; this can "
  417. "cause the process to hang. It is recommended to upgrade the kernel to the minimum version or higher."
  418. )
  419. logger.warning(msg, main_process_only=True)
  420. def recursive_getattr(obj, attr: str):
  421. """
  422. Recursive `getattr`.
  423. Args:
  424. obj:
  425. A class instance holding the attribute.
  426. attr (`str`):
  427. The attribute that is to be retrieved, e.g. 'attribute1.attribute2'.
  428. """
  429. def _getattr(obj, attr):
  430. return getattr(obj, attr)
  431. return reduce(_getattr, [obj] + attr.split("."))
  432. def get_module_children_bottom_up(model: torch.nn.Module, return_fqns: bool = False) -> list[torch.nn.Module]:
  433. """Traverse the model in bottom-up order and return the children modules in that order.
  434. Args:
  435. model (`torch.nn.Module`): the model to get the children of
  436. Returns:
  437. `list[torch.nn.Module]`: a list of children modules of `model` in bottom-up order. The last element is the
  438. `model` itself.
  439. """
  440. top = model if not return_fqns else ("", model)
  441. stack = [top]
  442. ordered_modules = []
  443. while stack:
  444. current_module = stack.pop()
  445. if return_fqns:
  446. current_module_name, current_module = current_module
  447. for name, attr in current_module.named_children():
  448. if isinstance(attr, torch.nn.Module):
  449. if return_fqns:
  450. child_name = current_module_name + "." + name if current_module_name else name
  451. stack.append((child_name, attr))
  452. else:
  453. stack.append(attr)
  454. if return_fqns:
  455. ordered_modules.append((current_module_name, current_module))
  456. else:
  457. ordered_modules.append(current_module)
  458. return ordered_modules[::-1]