quantizer_eetq.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import TYPE_CHECKING, Optional

from .base import HfQuantizer


if TYPE_CHECKING:
    # Only imported for static type checking; used in string annotations below.
    from ..modeling_utils import PreTrainedModel

from ..utils import is_accelerate_available, is_eetq_available, is_torch_available, logging
from .quantizers_utils import get_module_from_name


# torch is imported conditionally so this module can be imported without it;
# the quantizer methods themselves use `torch` unconditionally.
if is_torch_available():
    import torch

# Module-level logger used for the quantizer's warnings and info messages.
logger = logging.get_logger(__name__)
  23. class EetqHfQuantizer(HfQuantizer):
  24. """
  25. 8-bit quantization from EETQ quantization method:
  26. before loading: converts transformer layers into W8A16Linear during loading: load 16bit weight and pass to the
  27. layer object after: quantizes individual weights in Linear8bitLt into 8bit at first .cuda() call
  28. """
  29. requires_parameters_quantization = True
  30. requires_calibration = False
  31. required_packages = ["eetq", "accelerate"]
  32. def __init__(self, quantization_config, **kwargs):
  33. super().__init__(quantization_config, **kwargs)
  34. self.quantization_config = quantization_config
  35. def validate_environment(self, *args, **kwargs):
  36. if not is_eetq_available():
  37. raise ImportError(
  38. "Using `eetq` 8-bit quantization requires eetq."
  39. "Please install the latest version of eetq from : https://github.com/NetEase-FuXi/EETQ"
  40. )
  41. try:
  42. import eetq # noqa: F401
  43. except ImportError as exc:
  44. if "shard_checkpoint" in str(exc):
  45. # EETQ 1.0.0 is currently broken with the latest transformers because it tries to import the removed
  46. # shard_checkpoint function, see https://github.com/NetEase-FuXi/EETQ/issues/34.
  47. # TODO: Update message once eetq releases a fix
  48. raise ImportError(
  49. "You are using a version of EETQ that is incompatible with the current transformers version. "
  50. "Either downgrade transformers to <= v4.46.3 or, if available, upgrade EETQ to > v1.0.0."
  51. ) from exc
  52. else:
  53. raise
  54. if not is_accelerate_available():
  55. raise ImportError("Loading an EETQ quantized model requires accelerate (`pip install accelerate`)")
  56. if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
  57. raise ValueError(
  58. "Converting into 8-bit weights from tf/flax weights is currently not supported, please make"
  59. " sure the weights are in PyTorch format."
  60. )
  61. if not torch.cuda.is_available():
  62. raise RuntimeError("No GPU found. A GPU is needed for quantization.")
  63. device_map = kwargs.get("device_map")
  64. if device_map is None:
  65. logger.warning_once(
  66. "You have loaded an EETQ model on CPU and have a CUDA device available, make sure to set "
  67. "your model on a GPU device in order to run your model."
  68. )
  69. elif device_map is not None:
  70. if isinstance(device_map, dict) and ("cpu" in device_map.values() or "disk" in device_map.values()):
  71. raise ValueError(
  72. "You are attempting to load an EETQ model with a device_map that contains a CPU or disk device."
  73. " This is not supported. Please remove the CPU or disk device from the device_map."
  74. )
  75. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  76. if dtype is None:
  77. dtype = torch.float16
  78. logger.info(
  79. "Overriding dtype=%s with `dtype=torch.float16` due to "
  80. "requirements of `eetq` to enable model loading in 8-bit. "
  81. "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
  82. " dtype=torch.float16 to remove this warning.",
  83. dtype,
  84. )
  85. elif dtype != torch.float16:
  86. logger.info("We suggest you to set `dtype=torch.float16` for better efficiency with EETQ.")
  87. return dtype
  88. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  89. from eetq import EetqLinear
  90. module, tensor_name = get_module_from_name(model, param_name)
  91. if isinstance(module, EetqLinear):
  92. if self.pre_quantized or tensor_name == "bias":
  93. return False
  94. else:
  95. return True
  96. return False
  97. def create_quantized_param(
  98. self,
  99. model: "PreTrainedModel",
  100. param_value: "torch.Tensor",
  101. param_name: str,
  102. target_device: "torch.device",
  103. **kwargs,
  104. ):
  105. from eetq import EetqLinear, quantize_and_preprocess_weights
  106. module, tensor_name = get_module_from_name(model, param_name)
  107. new_value, weight_scale = quantize_and_preprocess_weights(param_value)
  108. # Samity check
  109. if isinstance(module, EetqLinear):
  110. if self.pre_quantized or tensor_name == "bias":
  111. if tensor_name == "weight" and param_value.dtype != torch.int8:
  112. raise ValueError("Expect quantized weights but got an unquantized weight")
  113. else:
  114. if tensor_name == "weight_scale":
  115. raise ValueError("Expect unquantized weights but got a quantized weight_scale")
  116. module._buffers[tensor_name] = new_value.to(target_device)
  117. module.register("weight_scales", weight_scale.to(target_device))
  118. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  119. return model
  120. def _process_model_before_weight_loading(
  121. self,
  122. model: "PreTrainedModel",
  123. keep_in_fp32_modules: Optional[list[str]] = None,
  124. **kwargs,
  125. ):
  126. from ..integrations import replace_with_eetq_linear
  127. self.modules_to_not_convert = self.get_modules_to_not_convert(
  128. model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
  129. )
  130. model = replace_with_eetq_linear(
  131. model,
  132. modules_to_not_convert=self.modules_to_not_convert,
  133. quantization_config=self.quantization_config,
  134. pre_quantized=self.pre_quantized,
  135. )
  136. model.config.quantization_config = self.quantization_config
  137. def is_serializable(self, safe_serialization=None):
  138. return True
  139. @property
  140. def is_trainable(self) -> bool:
  141. return True