ao.py 4.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. Needed utilities for torchao FP8 training.
  16. """
  17. from functools import partial
  18. from typing import TYPE_CHECKING, Callable, Optional
  19. import torch
  20. from .imports import is_torchao_available, torchao_required
  21. if TYPE_CHECKING:
  22. if is_torchao_available():
  23. from torchao.float8.float8_linear import Float8LinearConfig
  24. def find_first_last_linear_layers(model: torch.nn.Module):
  25. """
  26. Finds the first and last linear layer names in a model.
  27. This is needed during FP8 to avoid issues with instability by keeping the first and last layers unquantized.
  28. Ref: https://x.com/xariusrke/status/1826669142604141052
  29. """
  30. first_linear, last_linear = None, None
  31. for name, module in model.named_modules():
  32. if isinstance(module, torch.nn.Linear):
  33. if first_linear is None:
  34. first_linear = name
  35. last_linear = name
  36. return first_linear, last_linear
  37. def filter_linear_layers(module, fqn: str, layers_to_filter: list[str]) -> bool:
  38. """
  39. A function which will check if `module` is:
  40. - a `torch.nn.Linear` layer
  41. - has in_features and out_features divisible by 16
  42. - is not part of `layers_to_filter`
  43. Args:
  44. module (`torch.nn.Module`):
  45. The module to check.
  46. fqn (`str`):
  47. The fully qualified name of the layer.
  48. layers_to_filter (`List[str]`):
  49. The list of layers to filter.
  50. """
  51. if isinstance(module, torch.nn.Linear):
  52. if module.in_features % 16 != 0 or module.out_features % 16 != 0:
  53. return False
  54. if fqn in layers_to_filter:
  55. return False
  56. return True
  57. def filter_first_and_last_linear_layers(module, fqn: str) -> bool:
  58. """
  59. A filter function which will filter out all linear layers except the first and last.
  60. <Tip>
  61. For stability reasons, we skip the first and last linear layers Otherwise can lead to the model not training or
  62. converging properly
  63. </Tip>
  64. Args:
  65. module (`torch.nn.Module`):
  66. The module to check.
  67. fqn (`str`):
  68. The fully qualified name of the layer.
  69. """
  70. first_linear, last_linear = find_first_last_linear_layers(module)
  71. return filter_linear_layers(module, fqn, layers_to_filter=[first_linear, last_linear])
  72. @torchao_required
  73. def has_ao_layers(model: torch.nn.Module):
  74. from torchao.float8.float8_linear import Float8Linear
  75. for name, module in model.named_modules():
  76. if isinstance(module, Float8Linear):
  77. return True
  78. return False
  79. @torchao_required
  80. def convert_model_to_fp8_ao(
  81. model: torch.nn.Module,
  82. config: Optional["Float8LinearConfig"] = None,
  83. module_filter_func: Optional[Callable] = filter_first_and_last_linear_layers,
  84. ):
  85. """
  86. Converts all `nn.Linear` layers in the model (except the first and last) to torchao's `Float8Linear` layer inplace.
  87. Args:
  88. model (`torch.nn.Module`):
  89. The model to convert.
  90. config (`torchao.float8.Float8LinearConfig`, *optional*):
  91. The configuration for the FP8 training. Recommended to utilize
  92. `torchao.float8.recipe_name_to_linear_config` to generate this. In general, the default config should be
  93. sufficient (what is passed when set to `None`).
  94. module_filter_func (`Callable`, *optional*, defaults to `filter_linear_layers`):
  95. Optional function that must take in a module and layer name, and returns a boolean indicating whether the
  96. module should be converted to FP8. Defaults to `filter_linear_layers`. See it for an example.
  97. Example:
  98. ```python
  99. from accelerate.utils.ao import convert_model_to_fp8_ao
  100. model = MyModel()
  101. model.to("cuda")
  102. convert_to_float8_training(model)
  103. model.train()
  104. ```
  105. """
  106. from torchao.float8 import convert_to_float8_training
  107. first_linear, last_linear = find_first_last_linear_layers(model)
  108. if module_filter_func is None:
  109. module_filter_func = partial(filter_linear_layers, layers_to_filter=[first_linear, last_linear])
  110. convert_to_float8_training(model, module_filter_fn=module_filter_func, config=config)