accelerate.py 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
Since https://github.com/huggingface/transformers/pull/36963, loading is always performed with models on the meta
device. However, the `init_empty_weights` and `find_tied_parameters` functions come from accelerate, and accelerate
is still somewhat of a soft dependency, so we copy the functions here to be used natively in Transformers.
The `init_empty_weights` and `init_on_device` functions were copied from `accelerate.big_modeling.py`, and
`find_tied_parameters` was copied from `accelerate.utils.modeling.py`.
  20. """
  21. from contextlib import contextmanager
  22. from ..utils import is_torch_available, logging
  23. if is_torch_available():
  24. import torch
  25. import torch.nn as nn
  26. logger = logging.get_logger(__name__)
  27. @contextmanager
  28. def init_empty_weights(include_buffers: bool = False):
  29. """
  30. A context manager under which models are initialized with all parameters on the meta device, therefore creating an
  31. empty model. Useful when just initializing the model would blow the available RAM.
  32. Args:
  33. include_buffers (`bool`, *optional*):
  34. Whether or not to also put all buffers on the meta device while initializing.
  35. Example:
  36. ```python
  37. import torch.nn as nn
  38. from accelerate import init_empty_weights
  39. # Initialize a model with 100 billions parameters in no time and without using any RAM.
  40. with init_empty_weights():
  41. tst = nn.Sequential(*[nn.Linear(10000, 10000) for _ in range(1000)])
  42. ```
  43. <Tip warning={true}>
  44. Any model created under this context manager has no weights. As such you can't do something like
  45. `model.to(some_device)` with it. To load weights inside your empty model, see [`load_checkpoint_and_dispatch`].
  46. Make sure to overwrite the default device_map param for [`load_checkpoint_and_dispatch`], otherwise dispatch is not
  47. called.
  48. </Tip>
  49. """
  50. with init_on_device(torch.device("meta"), include_buffers=include_buffers) as f:
  51. yield f
  52. @contextmanager
  53. def init_on_device(device: "torch.device", include_buffers: bool = False):
  54. """
  55. A context manager under which models are initialized with all parameters on the specified device.
  56. Args:
  57. device (`torch.device`):
  58. Device to initialize all parameters on.
  59. include_buffers (`bool`, *optional*):
  60. Whether or not to also put all buffers on the meta device while initializing.
  61. Example:
  62. ```python
  63. import torch.nn as nn
  64. from accelerate import init_on_device
  65. with init_on_device(device=torch.device("cuda")):
  66. tst = nn.Linear(100, 100) # on `cuda` device
  67. ```
  68. """
  69. if include_buffers:
  70. with device:
  71. yield
  72. return
  73. old_register_parameter = nn.Module.register_parameter
  74. if include_buffers:
  75. old_register_buffer = nn.Module.register_buffer
  76. def register_empty_parameter(module, name, param):
  77. old_register_parameter(module, name, param)
  78. if param is not None:
  79. param_cls = type(module._parameters[name])
  80. kwargs = module._parameters[name].__dict__
  81. kwargs["requires_grad"] = param.requires_grad
  82. module._parameters[name] = param_cls(module._parameters[name].to(device), **kwargs)
  83. def register_empty_buffer(module, name, buffer, persistent=True):
  84. old_register_buffer(module, name, buffer, persistent=persistent)
  85. if buffer is not None:
  86. module._buffers[name] = module._buffers[name].to(device)
  87. # Patch tensor creation
  88. if include_buffers:
  89. tensor_constructors_to_patch = {
  90. torch_function_name: getattr(torch, torch_function_name)
  91. for torch_function_name in ["empty", "zeros", "ones", "full"]
  92. }
  93. else:
  94. tensor_constructors_to_patch = {}
  95. def patch_tensor_constructor(fn):
  96. def wrapper(*args, **kwargs):
  97. kwargs["device"] = device
  98. return fn(*args, **kwargs)
  99. return wrapper
  100. try:
  101. nn.Module.register_parameter = register_empty_parameter
  102. if include_buffers:
  103. nn.Module.register_buffer = register_empty_buffer
  104. for torch_function_name in tensor_constructors_to_patch:
  105. setattr(torch, torch_function_name, patch_tensor_constructor(getattr(torch, torch_function_name)))
  106. yield
  107. finally:
  108. nn.Module.register_parameter = old_register_parameter
  109. if include_buffers:
  110. nn.Module.register_buffer = old_register_buffer
  111. for torch_function_name, old_torch_function in tensor_constructors_to_patch.items():
  112. setattr(torch, torch_function_name, old_torch_function)
  113. def find_tied_parameters(model: "nn.Module", **kwargs):
  114. """
  115. Find the tied parameters in a given model.
  116. <Tip warning={true}>
  117. The signature accepts keyword arguments, but they are for the recursive part of this function and you should ignore
  118. them.
  119. </Tip>
  120. Args:
  121. model (`torch.nn.Module`): The model to inspect.
  122. Returns:
  123. list[list[str]]: A list of lists of parameter names being all tied together.
  124. Example:
  125. ```py
  126. >>> from collections import OrderedDict
  127. >>> import torch.nn as nn
  128. >>> model = nn.Sequential(OrderedDict([("linear1", nn.Linear(4, 4)), ("linear2", nn.Linear(4, 4))]))
  129. >>> model.linear2.weight = model.linear1.weight
  130. >>> find_tied_parameters(model)
  131. [['linear1.weight', 'linear2.weight']]
  132. ```
  133. """
  134. # get ALL model parameters and their names
  135. all_named_parameters = dict(model.named_parameters(remove_duplicate=False))
  136. # get ONLY unique named parameters,
  137. # if parameter is tied and have multiple names, it will be included only once
  138. no_duplicate_named_parameters = dict(model.named_parameters(remove_duplicate=True))
  139. # the difference of the two sets will give us the tied parameters
  140. tied_param_names = set(all_named_parameters.keys()) - set(no_duplicate_named_parameters.keys())
  141. # 'tied_param_names' contains the names of parameters that are tied in the model, but we do not know
  142. # which names refer to the same parameter. To identify this, we need to group them together.
  143. tied_param_groups = {}
  144. for tied_param_name in tied_param_names:
  145. tied_param = all_named_parameters[tied_param_name]
  146. for param_name, param in no_duplicate_named_parameters.items():
  147. # compare if parameters are the same, if so, group their names together
  148. if param is tied_param:
  149. if param_name not in tied_param_groups:
  150. tied_param_groups[param_name] = []
  151. tied_param_groups[param_name].append(tied_param_name)
  152. return [sorted([weight] + list(set(tied))) for weight, tied in tied_param_groups.items()]