| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316 |
- #!/usr/bin/env python
- # Copyright 2023 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import Optional
- import torch
- from huggingface_hub import model_info
- from huggingface_hub.utils import GatedRepoError, RepositoryNotFoundError
- from accelerate import init_empty_weights
- from accelerate.commands.utils import CustomArgumentParser
- from accelerate.utils import (
- calculate_maximum_sizes,
- convert_bytes,
- is_timm_available,
- is_transformers_available,
- )
- if is_transformers_available():
- import transformers
- from transformers import AutoConfig, AutoModel
- if is_timm_available():
- import timm
def verify_on_hub(repo: str, token: Optional[str] = None):
    """
    Verifies that the model is on the hub and returns the model info.

    Args:
        repo (`str`):
            The repository id on the Hub.
        token (`str`, `optional`, defaults to `None`):
            Access token used for private or gated repositories.

    Returns:
        The object returned by `huggingface_hub.model_info` on success, otherwise one of the sentinel strings:
        `"gated"` (gated repo, or the lookup failed with another `OSError`) or `"repo"` (repository not found).
    """
    try:
        return model_info(repo, token=token)
    except GatedRepoError:
        # `GatedRepoError` is a subclass of `RepositoryNotFoundError` in `huggingface_hub`, so it must be caught first.
        return "gated"
    except RepositoryNotFoundError:
        return "repo"
    except OSError:
        # Hub HTTP errors can derive from `OSError` (e.g. via `requests.HTTPError`), so this generic handler must
        # come last; otherwise it would swallow `RepositoryNotFoundError` and report a missing repo as gated.
        return "gated"
def check_has_model(error):
    """
    Checks what library spawned `error` when a model is not found.

    Args:
        error (`Exception`):
            The exception raised while trying to create the empty model.

    Returns:
        `str`: `"timm"` or `"transformers"` when the error message matches that library's "model not found"
        error, otherwise `"unknown"`.
    """
    # Use `str(error)` rather than `error.args[0]`: an error may carry no args, or a non-string first arg
    # (e.g. `OSError(errno, strerror)`), and indexing/`in` on those would raise and mask the original error.
    message = str(error)
    if isinstance(error, RuntimeError) and is_timm_available() and "Unknown model" in message:
        return "timm"
    if (
        isinstance(error, OSError)
        and is_transformers_available()
        and "does not appear to have a file named" in message
    ):
        return "transformers"
    return "unknown"
def create_empty_model(
    model_name: str, library_name: str, trust_remote_code: bool = False, access_token: Optional[str] = None
):
    """
    Builds `model_name` in full precision inside `init_empty_weights()` so its memory consumption can be measured
    without loading the real weights.

    Args:
        model_name (`str`):
            Name of the model repository on the Hub.
        library_name (`str`):
            Library used to instantiate the model (`"transformers"` or `"timm"`). When `None`, the `library_name`
            stored in the repo's Hub metadata is used instead.
        trust_remote_code (`bool`, `optional`, defaults to `False`):
            Whether to allow custom models defined on the Hub in their own modeling files. Only enable this for
            repositories you trust and whose code you have read, since that code is executed on your local machine.
        access_token (`str`, `optional`, defaults to `None`):
            Access token for private or gated models on the Hub (for use on the Gradio app).

    Returns:
        `torch.nn.Module`: The torch model that has been initialized on the `meta` device.

    Raises:
        `GatedRepoError`: `verify_on_hub` reported the repo as gated (or the lookup failed with an `OSError`).
        `RepositoryNotFoundError`: `verify_on_hub` reported the repo as missing.
        `ValueError`: no library could be determined, or the library is not supported.
        `ImportError`: the library needed to build the model is not installed.
        `RuntimeError`: a `transformers` model has no config metadata on the Hub.
    """
    hub_info = verify_on_hub(model_name, access_token)

    # `verify_on_hub` reports failures as sentinel strings; surface them as the matching Hub errors.
    if hub_info == "gated":
        raise GatedRepoError(
            f"Repo for model `{model_name}` is gated. You must be authenticated to access it. Please run `huggingface-cli login`."
        )
    if hub_info == "repo":
        raise RepositoryNotFoundError(
            f"Repo for model `{model_name}` does not exist on the Hub. If you are trying to access a private repo,"
            " make sure you are authenticated via `huggingface-cli login` and have access."
        )

    # Fall back to the library declared in the Hub metadata when the caller did not pick one.
    if library_name is None:
        library_name = getattr(hub_info, "library_name", False)
        if not library_name:
            raise ValueError(
                f"Model `{model_name}` does not have any library metadata on the Hub, please manually pass in a `--library_name` to use (such as `transformers`)"
            )

    if library_name not in ("transformers", "timm"):
        raise ValueError(
            f"Library `{library_name}` is not supported yet, please open an issue on GitHub for us to add support."
        )

    if library_name == "timm":
        if not is_timm_available():
            raise ImportError(
                f"To check `{model_name}`, `timm` must be installed. Please install it via `pip install timm`"
            )
        print(f"Loading pretrained config for `{model_name}` from `timm`...")
        with init_empty_weights():
            model = timm.create_model(model_name, pretrained=False)
        return model

    # Only `library_name == "transformers"` reaches this point.
    if not is_transformers_available():
        raise ImportError(
            f"To check `{model_name}`, `transformers` must be installed. Please install it via `pip install transformers`"
        )
    print(f"Loading pretrained config for `{model_name}` from `transformers`...")
    if hub_info.config is None:
        raise RuntimeError(f"Tried to load `{model_name}` with `transformers` but it does not have any metadata.")

    auto_map = hub_info.config.get("auto_map", False)
    config = AutoConfig.from_pretrained(model_name, trust_remote_code=trust_remote_code, token=access_token)
    with init_empty_weights():
        # Remote code may name a specific `AutoModelFor*` class in `auto_map`; the first such key wins.
        auto_class_name = None
        if isinstance(auto_map, dict):
            auto_class_name = next((key for key in auto_map if key.startswith("AutoModelFor")), None)
        constructor = AutoModel if auto_class_name is None else getattr(transformers, auto_class_name)
        # Pass the dtype explicitly, otherwise the `torch_dtype` saved in the config would be used.
        model = constructor.from_config(config, torch_dtype=torch.float32, trust_remote_code=trust_remote_code)
    return model
def create_ascii_table(headers: list, rows: list, title: str):
    """
    Creates a pretty table from a list of rows, minimal version of `tabulate`.

    The title is shown centered in its own box above the header row, and every cell is centered in its column.
    If the title is wider than the columns, the last column is widened just enough so every line has the same length.

    Args:
        headers (`list`):
            Column names. Each is converted with `str()`.
        rows (`list`):
            Rows of cells, one cell per header. Cells are converted with `str()`. May be empty.
        title (`str`):
            Text displayed above the header row.

    Returns:
        `str`: The rendered table as newline-separated lines, without a trailing newline.
    """
    sep_char, in_between = "│", "─"
    headers = [str(header) for header in headers]
    rows = [[str(cell) for cell in row] for row in rows]

    column_widths = [max([len(header)] + [len(row[i]) for row in rows]) for i, header in enumerate(headers)]
    # Width between the outer borders: every column plus one separator between each pair of columns.
    inner_width = sum(column_widths) + len(column_widths) - 1
    if len(title) > inner_width:
        # Widen only the last column, and only by what the title needs, so all lines stay aligned.
        column_widths[-1] += len(title) - inner_width
        inner_width = len(title)

    def make_row(left_char, middle_char, right_char):
        # A border line: one run of `in_between` per column, joined by `middle_char`.
        return left_char + middle_char.join(in_between * width for width in column_widths) + right_char

    def make_line(cells):
        # A content line: every cell centered in its column.
        return sep_char + sep_char.join(cell.center(width) for cell, width in zip(cells, column_widths)) + sep_char

    lines = [
        make_row("┌", in_between, "┐"),
        f"{sep_char}{title.center(inner_width)}{sep_char}",
        make_row("├", "┬", "┤"),
        make_line(headers),
        make_row("├", "┼", "┤"),
        *(make_line(row) for row in rows),
        make_row("└", "┴", "┘"),
    ]
    return "\n".join(lines)
def estimate_command_parser(subparsers=None):
    """
    Builds the argument parser for the `estimate-memory` command.

    When `subparsers` is given, the command is registered on it and `func` defaults to `estimate_command`;
    otherwise a standalone `CustomArgumentParser` is created.
    """
    supported_dtypes = ["float32", "float16", "int8", "int4"]

    if subparsers is None:
        parser = CustomArgumentParser(description="Model size estimator for fitting a model onto CUDA memory.")
    else:
        parser = subparsers.add_parser("estimate-memory")

    parser.add_argument("model_name", type=str, help="The model name on the Hugging Face Hub.")
    parser.add_argument(
        "--library_name",
        type=str,
        choices=["timm", "transformers"],
        help="The library the model has an integration with, such as `transformers`, needed only if this information is not stored on the Hub.",
    )
    parser.add_argument(
        "--dtypes",
        type=str,
        nargs="+",
        choices=supported_dtypes,
        # Separate list object so the parsed default never aliases `choices`.
        default=list(supported_dtypes),
        help="The dtypes to use for the model, must be one (or many) of `float32`, `float16`, `int8`, and `int4`",
    )
    parser.add_argument(
        "--trust_remote_code",
        action="store_true",
        default=False,
        help="""Whether or not to allow for custom models defined on the Hub in their own modeling files. This flag
        should only be used for repositories you trust and in which you have read the code, as it will execute
        code present on the Hub on your local machine.""",
    )

    if subparsers is not None:
        parser.set_defaults(func=estimate_command)
    return parser
def estimate_training_usage(bytes: int, mixed_precision: str, msamp_config: Optional[str] = None) -> dict:
    """
    Estimates, component by component, the memory needed to train a model of size `bytes` with a batch size of 1.

    Args:
        bytes (`int`):
            The size of the model being trained, counted at full (fp32) precision.
        mixed_precision (`str`):
            The precision the training would run in.
        msamp_config (`str`, `optional`, defaults to `None`):
            The msamp config to estimate the training memory for if `mixed_precision` is set to `"fp8"`.

    Returns:
        `dict`: Keys `"model"`, `"optimizer"`, `"gradients"` and `"step"` mapped to sizes. All values are `-1`
        when no estimate exists for `mixed_precision` (anything other than `"float32"`, `"float16"`, `"bfloat16"`,
        or `"fp8"` without an msamp config).
    """
    full_size = bytes
    half_size = bytes // 2

    if mixed_precision == "float32":
        return {
            "model": full_size,
            "optimizer": full_size * 2,
            "gradients": full_size,
            "step": full_size * 4,
        }

    # Native FP8 (no msamp config, i.e. `TransformersEngine`) brings no memory savings, so it is costed like fp16/bf16.
    if mixed_precision in ("float16", "bfloat16") or (mixed_precision == "fp8" and msamp_config is None):
        optimizer_size = full_size * 2  # 2x for the optimizer states
        return {
            "model": full_size,
            "optimizer": optimizer_size,
            "gradients": full_size + half_size,  # 1.5x: weight gradient + computation (GEMM)
            "step": optimizer_size,
        }

    # No estimate is available (e.g. integer dtypes, or fp8 with an msamp config).
    return {"model": -1, "optimizer": -1, "gradients": -1, "step": -1}
def gather_data(args):
    """
    Creates an empty model and gathers the data for the sizes.

    Args:
        args:
            Parsed CLI arguments; `model_name`, `library_name`, `trust_remote_code` and `dtypes` are read.

    Returns:
        `list`: One row per entry of `args.dtypes`:
        `[dtype, largest_layer_size, total_size, training_usage_dict]`.

    Raises:
        `RuntimeError`: the library was identified but no loadable model was found in the repo (chained to the
        original error); other `RuntimeError`/`OSError`s from model creation are re-raised unchanged.
    """
    try:
        model = create_empty_model(
            args.model_name, library_name=args.library_name, trust_remote_code=args.trust_remote_code
        )
    except (RuntimeError, OSError) as e:
        library = check_has_model(e)
        if library != "unknown":
            raise RuntimeError(
                f"Tried to load `{args.model_name}` with `{library}` but a possible model to load was not found inside the repo."
            ) from e
        raise

    total_size, largest_layer = calculate_maximum_sizes(model)
    largest_layer_size = largest_layer[0]
    # The empty model is built in float32 (4 bytes per value); other dtypes need this fraction of those bytes.
    size_divisors = {"float16": 2, "int8": 4, "int4": 8}

    data = []
    for dtype in args.dtypes:
        # Training memory is estimated from the full-precision size, independent of the reporting dtype.
        training_usage = estimate_training_usage(total_size, dtype)
        dtype_total_size = total_size
        dtype_largest_layer = largest_layer_size
        divisor = size_divisors.get(dtype)
        if divisor is not None:
            dtype_total_size /= divisor
            dtype_largest_layer /= divisor
        data.append([dtype, dtype_largest_layer, dtype_total_size, training_usage])
    return data
def estimate_command(args):
    """Gathers the memory estimates for `args.model_name` and prints them as an ASCII table."""

    def format_cell(cell):
        # Training-usage dicts are summarized by their largest component; `-1` means "no estimate".
        if isinstance(cell, dict):
            peak = max(cell.values())
            return "N/A" if peak == -1 else convert_bytes(peak)
        if isinstance(cell, (int, float)):
            return convert_bytes(cell)
        return cell

    formatted_rows = [[format_cell(cell) for cell in row] for row in gather_data(args)]
    table = create_ascii_table(
        ["dtype", "Largest Layer", "Total Size", "Training using Adam"],
        formatted_rows,
        f"Memory Usage for loading `{args.model_name}`",
    )
    print(table)
def main():
    """Entry point when this module is run as a standalone script: parse the CLI and run the estimate."""
    cli_args = estimate_command_parser().parse_args()
    estimate_command(cli_args)


if __name__ == "__main__":
    main()
|