| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140 |
- # Copyright 2025 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """
- Needed utilities for torchao FP8 training.
- """
- from functools import partial
- from typing import TYPE_CHECKING, Callable, Optional
- import torch
- from .imports import is_torchao_available, torchao_required
- if TYPE_CHECKING:
- if is_torchao_available():
- from torchao.float8.float8_linear import Float8LinearConfig
- def find_first_last_linear_layers(model: torch.nn.Module):
- """
- Finds the first and last linear layer names in a model.
- This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.
- Ref: https://x.com/xariusrke/status/1826669142604141052
- """
- first_linear, last_linear = None, None
- for name, module in model.named_modules():
- if isinstance(module, torch.nn.Linear):
- if first_linear is None:
- first_linear = name
- last_linear = name
- return first_linear, last_linear
- def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool:
- """
- A function which will check if `module` is:
- - a `torch.nn.Linear` layer
- - has in_features and out_features divisible by 16
- - is not part of `layers_to_filter`
- Args:
- module (`torch.nn.Module`):
- The module to check.
- fqn (`str`):
- The fully qualified name of the layer.
- layers_to_filter (`List[str]`):
- The list of layers to filter.
- """
- if isinstance(module, torch.nn.Linear):
- if module.in_features % 16 != 0 or module.out_features % 16 != 0:
- return False
- if fqn in layers_to_filter:
- return False
- return True
- def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
- """
- A filter function which will filter out all linear layers except the first and last.
- <Tip>
- For stability reasons, we skip the first and last linear layers Otherwise can lead to the model not training or
- converging properly
- </Tip>
- Args:
- module (`torch.nn.Module`):
- The module to check.
- fqn (`str`):
- The fully qualified name of the layer.
- """
- first_linear, last_linear = find_first_last_linear_layers(module)
- return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear])
- @torchao_required
- def has_ao_layers(model: torch.nn.Module):
- from torchao.float8.float8_linear import Float8Linear
- for name, module in model.named_modules():
- if isinstance(module, Float8Linear):
- return True
- return False
- @torchao_required
- def convert_model_to_fp8_ao(
- model: torch.nn.Module,
- config: Optional["Float8LinearConfig"] = None,
- module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers,
- ):
- """
- Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace.
- Args:
- model (`torch.nn.Module`):
- The model to convert.
- config (`torchao.float8.Float8LinearConfig`, *optional*):
- The configuration for the FP8 training. Recommended to utilize
- `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
- sufficient (what is passed when set to `None`).
- module_filter_func (`Callable`, *optional*, defaults to `filter_linear_layers`):
- Optional function that must take in a module and layer name, and returns a boolean indicating whether the
- module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example.
- Example:
- ```python
- from accelerate.utils.ao import convert_model_to_fp8_ao
- model = MyModel()
- model.to("cuda")
- convert_to_float8_training(model)
- model.train()
- ```
- """
- from torchao.float8 import convert_to_float8_training
- first_linear, last_linear = find_first_last_linear_layers(model)
- if module_filter_func is None:
- module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])
- convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)
|