eetq.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # coding=utf-8
  2. # Copyright 2024 NetEase, Inc. and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from ..utils import is_accelerate_available, is_eetq_available, logging
  16. if is_eetq_available():
  17. import eetq
  18. import torch.nn as nn
  19. if is_accelerate_available():
  20. from accelerate import init_empty_weights
# Module-level logger for this integration (used to warn when no linear layer gets replaced).
logger = logging.get_logger(__name__)
  22. def _replace_with_eetq_linear(
  23. model,
  24. modules_to_not_convert=None,
  25. current_key_name=None,
  26. quantization_config=None,
  27. has_been_replaced=False,
  28. pre_quantized=False,
  29. ):
  30. """
  31. Private method that wraps the recursion for module replacement.
  32. Returns the converted model and a boolean that indicates if the conversion has been successful or not.
  33. """
  34. if current_key_name is None:
  35. current_key_name = []
  36. for name, module in model.named_children():
  37. current_key_name.append(name)
  38. if (isinstance(module, nn.Linear)) and name not in modules_to_not_convert:
  39. # Check if the current key is not in the `modules_to_not_convert`
  40. current_key_name_str = ".".join(current_key_name)
  41. if not any(
  42. (key + "." in current_key_name_str) or (key == current_key_name_str) for key in modules_to_not_convert
  43. ):
  44. with init_empty_weights():
  45. in_features = module.in_features
  46. out_features = module.out_features
  47. model._modules[name] = eetq.EetqLinear(
  48. in_features, out_features, module.bias is not None, module.weight.device
  49. )
  50. if pre_quantized:
  51. model._modules[name].register_scale(module.weight.device)
  52. has_been_replaced = True
  53. # Force requires grad to False to avoid unexpected errors
  54. model._modules[name].requires_grad_(False)
  55. if len(list(module.children())) > 0:
  56. _, has_been_replaced = _replace_with_eetq_linear(
  57. module,
  58. modules_to_not_convert,
  59. current_key_name,
  60. quantization_config,
  61. has_been_replaced=has_been_replaced,
  62. pre_quantized=pre_quantized,
  63. )
  64. # Remove the last key for recursion
  65. current_key_name.pop(-1)
  66. return model, has_been_replaced
  67. def replace_with_eetq_linear(
  68. model, modules_to_not_convert=None, current_key_name=None, quantization_config=None, pre_quantized=False
  69. ):
  70. """
  71. A helper function to replace all `torch.nn.Linear` modules by `eetq.EetqLinear` modules from the `eetq`
  72. library. This will enable running your models using high performance int8 weight-only gemm kerner from
  73. FasterTransformer and TensorRT-LLM. Make sure `eetq` compiled with the correct CUDA
  74. version of your hardware is installed before running this function. EETQ shall be installed via the source
  75. 'https://github.com/NetEase-FuXi/EETQ'
  76. The function will be run recursively and replace all `torch.nn.Linear` modules except for the `lm_head` that should
  77. be kept as a `torch.nn.Linear` module. The replacement is done under `init_empty_weights` context manager so no
  78. CPU/GPU memory is required to run this function. Each weight will be quantized along the channel.
  79. Parameters:
  80. model (`torch.nn.Module`):
  81. Input model or `torch.nn.Module` as the function is run recursively.
  82. modules_to_not_convert (`list[`str`]`, *optional*, defaults to `["lm_head"]`):
  83. Names of the modules to not convert in `EetqLinear`. In practice we keep the `lm_head` in full precision
  84. for numerical stability reasons.
  85. current_key_name (`list[`str`]`, *optional*):
  86. An array to track the current key of the recursion. This is used to check whether the current key (part of
  87. it) is not in the list of modules to not convert (for instances modules that are offloaded to `cpu` or
  88. `disk`).
  89. """
  90. modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
  91. if quantization_config.modules_to_not_convert is not None:
  92. modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
  93. modules_to_not_convert = list(set(modules_to_not_convert))
  94. model, has_been_replaced = _replace_with_eetq_linear(
  95. model, modules_to_not_convert, current_key_name, quantization_config, pre_quantized=pre_quantized
  96. )
  97. if not has_been_replaced:
  98. logger.warning(
  99. "You are loading your model using eetq but no linear modules were found in your model."
  100. " Please double check your model architecture, or submit an issue on github if you think this is"
  101. " a bug."
  102. )
  103. return model