quantizer_vptq.py 3.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899
  1. # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Optional
  15. from .base import HfQuantizer
  16. if TYPE_CHECKING:
  17. from ..modeling_utils import PreTrainedModel
  18. from ..utils import is_accelerate_available, is_torch_available, is_vptq_available, logging
  19. from ..utils.quantization_config import QuantizationConfigMixin
  20. if is_torch_available():
  21. import torch
# Module-level logger; `VptqHfQuantizer.update_dtype` uses it to report which default dtype it picked.
logger = logging.get_logger(__name__)
  23. class VptqHfQuantizer(HfQuantizer):
  24. """
  25. Quantizer of the VPTQ method. Enables the loading of prequantized models.
  26. """
  27. requires_calibration = True
  28. required_packages = ["vptq"]
  29. def __init__(self, quantization_config: QuantizationConfigMixin, **kwargs):
  30. super().__init__(quantization_config, **kwargs)
  31. self.quantization_config = quantization_config
  32. def validate_environment(self, *args, **kwargs):
  33. if not is_accelerate_available():
  34. raise ImportError("Using `vptq` quantization requires Accelerate: `pip install accelerate`")
  35. if not is_vptq_available():
  36. raise ImportError("Using `vptq` quantization requires VPTQ>=0.0.4: `pip install -U vptq`")
  37. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  38. if dtype is None:
  39. if torch.cuda.is_available():
  40. dtype = torch.float16
  41. logger.info(
  42. "CUDA available. Assuming VPTQ inference on GPU and loading the model in `torch.float16`. To overwrite it, set `dtype` manually."
  43. )
  44. else:
  45. import vptq
  46. device_availability = getattr(vptq, "device_availability", lambda device: False)
  47. if device_availability("cpu") is True:
  48. raise RuntimeError("No GPU found. Please wait for the next release of VPTQ to use CPU inference")
  49. dtype = torch.float32
  50. logger.info("No GPU found. Assuming VPTQ inference on CPU and loading the model in `torch.float32`.")
  51. return dtype
  52. def _process_model_before_weight_loading(
  53. self,
  54. model: "PreTrainedModel",
  55. keep_in_fp32_modules: Optional[list[str]] = None,
  56. **kwargs,
  57. ):
  58. """
  59. we don't have param like modules_to_not_convert to indicate which layers should not be quantized
  60. because `quantization_config` include the layers that should be quantized
  61. """
  62. from ..integrations import replace_with_vptq_linear
  63. self.modules_to_not_convert = self.get_modules_to_not_convert(
  64. model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
  65. )
  66. replace_with_vptq_linear(
  67. model,
  68. quantization_config=self.quantization_config,
  69. modules_to_not_convert=self.modules_to_not_convert,
  70. )
  71. model.config.quantization_config = self.quantization_config
  72. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  73. return model
  74. @property
  75. def is_trainable(self) -> bool:
  76. return False
  77. def is_serializable(self, safe_serialization=None):
  78. return True