| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424 |
- # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from typing import TYPE_CHECKING, Optional
- from .base import HfQuantizer
- if TYPE_CHECKING:
- from ..modeling_utils import PreTrainedModel
- from ..utils import (
- is_accelerate_available,
- is_kernels_available,
- is_torch_available,
- is_triton_available,
- logging,
- )
- from .quantizers_utils import get_module_from_name
- if is_torch_available():
- import torch
- logger = logging.get_logger(__name__)
- triton_kernels_hub = None
- class Mxfp4HfQuantizer(HfQuantizer):
- """
- FP4 quantization using fbgemm kernels
- """
- requires_parameters_quantization = True
- requires_calibration = False
- required_packages = ["accelerate"]
- def __init__(self, quantization_config, **kwargs):
- super().__init__(quantization_config, **kwargs)
- self.quantization_config = quantization_config
- self.triton_kernels_hub = None
- def _lazy_import_kernels(self):
- """Lazy import and initialize kernels only when needed"""
- if self.triton_kernels_hub is None:
- try:
- from kernels import get_kernel
- self.triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
- except ImportError:
- raise ImportError("kernels package is required for MXFP4 quantization")
- return self.triton_kernels_hub
- def validate_environment(self, *args, **kwargs):
- if not is_torch_available():
- raise ImportError(
- "Using mxfp4 quantization requires torch"
- "Please install the latest version of torch ( pip install --upgrade torch )"
- )
- if self.quantization_config.dequantize:
- return
- if not (torch.cuda.is_available() or torch.xpu.is_available()):
- if self.pre_quantized:
- logger.warning_once(
- "Using MXFP4 quantized models requires a GPU, we will default to dequantizing the model to bf16"
- )
- self.quantization_config.dequantize = True
- return
- else:
- raise RuntimeError("Quantizing a model using MXFP4 requires a GPU")
- if not is_accelerate_available():
- raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
- if torch.xpu.is_available():
- gpu_is_supported = True
- kernels_available = is_triton_available("3.5.0") and is_kernels_available()
- else:
- compute_capability = torch.cuda.get_device_capability()
- gpu_is_supported = compute_capability >= (7, 5)
- kernels_available = is_triton_available("3.4.0") and is_kernels_available()
- if self.pre_quantized:
- # On unsupported GPUs or without kernels, we will dequantize the model to bf16
- if not gpu_is_supported:
- logger.warning_once(
- "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series) "
- "We will default to dequantizing the model to bf16."
- )
- self.quantization_config.dequantize = True
- return
- if not kernels_available:
- logger.warning_once(
- "MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16"
- )
- self.quantization_config.dequantize = True
- return
- elif not gpu_is_supported:
- # we can't quantize the model in this case so we raise an error
- raise ValueError(
- "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series) "
- )
- elif not kernels_available:
- # we can't quantize the model in this case so we raise an error
- raise ValueError(
- "MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0"
- )
- if not self.pre_quantized:
- self._lazy_import_kernels()
- device_map = kwargs.get("device_map")
- if device_map is None:
- logger.warning_once(
- "You have loaded an FP4 model on CPU and have a CUDA/XPU device available, make sure to set "
- "your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or device_map = 'xpu'. "
- )
- elif device_map is not None:
- if (
- not self.pre_quantized
- and isinstance(device_map, dict)
- and ("cpu" in device_map.values() or "disk" in device_map.values())
- ):
- raise ValueError(
- "You are attempting to load an FP4 model with a device_map that contains a CPU or disk device."
- "This is not supported when the model is quantized on the fly. "
- "Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
- )
- def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
- if dtype is None:
- dtype = torch.bfloat16
- logger.info(
- "Overriding dtype=%s with `dtype=torch.bfloat16` due to "
- "requirements of `fbgemm-gpu` to enable model loading in fp4. "
- "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
- " dtype=torch.bfloat16 to remove this warning.",
- dtype,
- )
- return dtype
- def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
- from ..integrations import Mxfp4GptOssExperts
- from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
- # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently
- if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name):
- module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")])
- else:
- module, tensor_name = get_module_from_name(model, param_name)
- if isinstance(module, Mxfp4GptOssExperts) or (
- isinstance(module, GptOssExperts) and self.quantization_config.dequantize
- ):
- if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]:
- return False
- return True
- return False
- def create_quantized_param(
- self,
- model: "PreTrainedModel",
- param_value: "torch.Tensor",
- param_name: str,
- target_device: "torch.device",
- **kwargs,
- ):
- from ..integrations import (
- Mxfp4GptOssExperts,
- dequantize,
- load_and_swizzle_mxfp4,
- quantize_to_mxfp4,
- swizzle_mxfp4,
- )
- from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
- if not self.pre_quantized:
- triton_kernels_hub = self._lazy_import_kernels()
- module, _ = get_module_from_name(model, param_name)
- with torch.device(target_device):
- if isinstance(module, Mxfp4GptOssExperts):
- triton_weight_tensor, weight_scale = quantize_to_mxfp4(param_value, triton_kernels_hub)
- PrecisionConfig, FlexCtx, InFlexData = (
- triton_kernels_hub.matmul_ogs.PrecisionConfig,
- triton_kernels_hub.matmul_ogs.FlexCtx,
- triton_kernels_hub.matmul_ogs.InFlexData,
- )
- triton_weight_tensor, weight_scale = swizzle_mxfp4(
- triton_weight_tensor, weight_scale, triton_kernels_hub
- )
- proj = "gate_up_proj" if "gate_up_proj" in param_name else "down_proj"
- setattr(module, proj, triton_weight_tensor)
- setattr(
- module,
- f"{proj}_precision_config",
- PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
- )
- delattr(module, f"{proj}_blocks")
- delattr(module, f"{proj}_scales")
- # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
- else:
- # This is when loading a quantized model (blocks and scales exist)
- empty_param = kwargs.get("empty_param")
- casting_dtype = kwargs.get("casting_dtype")
- to_contiguous = kwargs.get("to_contiguous")
- rank = kwargs.get("rank")
- device_mesh = kwargs.get("device_mesh")
- if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
- # blocks and scales have the same length that's why this works for both
- module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
- else:
- module, _ = get_module_from_name(model, param_name)
- shard_kwargs = {
- "empty_param": empty_param,
- "casting_dtype": casting_dtype,
- "to_contiguous": to_contiguous,
- "rank": rank,
- "device_mesh": device_mesh,
- "model": model,
- }
- if isinstance(module, Mxfp4GptOssExperts) or (
- isinstance(module, GptOssExperts) and self.quantization_config.dequantize
- ):
- if self.quantization_config.dequantize:
- # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears
- # so we only have the original param name
- dq_param_name = param_name[: -len("_blocks")]
- dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs)
- else:
- load_and_swizzle_mxfp4(
- module,
- param_name,
- param_value,
- target_device,
- self._lazy_import_kernels(),
- **shard_kwargs,
- )
- def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
- # we are not really dequantizing, we are just removing everything related to quantization here
- if self.quantization_config.dequantize:
- self.remove_quantization_config(model)
- # clean cache due to triton ops
- if torch.cuda.is_available():
- torch.cuda.empty_cache()
- elif torch.xpu.is_available():
- torch.xpu.empty_cache()
- def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]):
- # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants
- new_expected_keys = []
- for key in expected_keys:
- if key.endswith(".mlp.experts.gate_up_proj"):
- base = key[: -len("gate_up_proj")]
- new_expected_keys.append(base + "gate_up_proj_blocks")
- new_expected_keys.append(base + "gate_up_proj_scales")
- elif key.endswith(".mlp.experts.down_proj"):
- base = key[: -len("down_proj")]
- new_expected_keys.append(base + "down_proj_blocks")
- new_expected_keys.append(base + "down_proj_scales")
- elif not self.pre_quantized:
- # in this case, we are quantizing the model so we need to update the keys as we changed the layers
- if key.endswith(".mlp.experts.down_proj_blocks"):
- base = key[: -len("down_proj_blocks")]
- new_expected_keys.append(base + "down_proj")
- elif key.endswith(".mlp.experts.gate_up_proj_blocks"):
- base = key[: -len("gate_up_proj_blocks")]
- new_expected_keys.append(base + "gate_up_proj")
- elif key.endswith("scales"):
- # we remove it the scales as the checkpoint don't contain them
- continue
- else:
- new_expected_keys.append(key)
- else:
- new_expected_keys.append(key)
- return new_expected_keys
- def _process_model_before_weight_loading(
- self,
- model: "PreTrainedModel",
- keep_in_fp32_modules: Optional[list[str]] = None,
- **kwargs,
- ):
- from ..integrations import replace_with_mxfp4_linear
- self.modules_to_not_convert = self.get_modules_to_not_convert(
- model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
- )
- use_kernels = kwargs.get("use_kernels", False)
- # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
- if use_kernels:
- logger.warning_once(
- "You are using full precision kernels, we will dequantize the model to bf16. "
- "To use the quantized model with quantization kernels, please set use_kernels=False"
- )
- self.quantization_config.dequantize = True
- config = model.config
- model = replace_with_mxfp4_linear(
- model,
- modules_to_not_convert=self.modules_to_not_convert,
- quantization_config=self.quantization_config,
- config=config,
- )
- model.config.quantization_config = self.quantization_config
- def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
- from ..integrations import Mxfp4GptOssExperts
- not_missing_keys = []
- for name, module in model.named_modules():
- if isinstance(module, Mxfp4GptOssExperts):
- for missing in missing_keys:
- if (
- (name in missing or name in f"{prefix}.{missing}")
- and not missing.endswith(".weight")
- and not missing.endswith(".bias")
- ):
- not_missing_keys.append(missing)
- return [k for k in missing_keys if k not in not_missing_keys]
- def update_tp_plan(self, config):
- if "GptOssConfig" in config.__class__.__name__:
- if getattr(config, "base_model_tp_plan", None) is not None:
- config.base_model_tp_plan.update(
- {
- "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
- "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
- "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
- "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
- }
- )
- return config
- def update_ep_plan(self, config):
- if "GptOssConfig" in config.__class__.__name__:
- if getattr(config, "base_model_ep_plan", None) is not None:
- config.base_model_ep_plan.update(
- {
- "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
- "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
- "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
- "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
- }
- )
- return config
- def get_param_name(self, param_name: str) -> str:
- if self.quantization_config.dequantize:
- if "_blocks" in param_name:
- return param_name.replace("_blocks", "")
- elif "_scales" in param_name:
- return param_name.replace("_scales", "")
- elif not self.pre_quantized:
- if param_name.endswith("gate_up_proj"):
- return param_name.replace("gate_up_proj", "gate_up_proj_blocks")
- if param_name.endswith("down_proj"):
- return param_name.replace("down_proj", "down_proj_blocks")
- return param_name
- def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
- from ..integrations import Mxfp4GptOssExperts
- state_dict = model.state_dict()
- for name, module in model.named_modules():
- if (
- isinstance(module, Mxfp4GptOssExperts)
- and hasattr(module, "gate_up_proj")
- and hasattr(module, "down_proj")
- ):
- state_dict[f"{name}.gate_up_proj_blocks"] = (
- module.gate_up_proj.storage.layout.unswizzle_data(module.gate_up_proj.storage.data)
- .transpose(-1, -2)
- .reshape(32, -1, 90, 16)
- )
- state_dict[f"{name}.gate_up_proj_scales"] = (
- module.gate_up_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
- module.gate_up_proj_precision_config.weight_scale.storage.data
- ).transpose(-1, -2)
- )
- state_dict[f"{name}.down_proj_blocks"] = (
- module.down_proj.storage.layout.unswizzle_data(module.down_proj.storage.data)
- .transpose(-1, -2)
- .reshape(32, 2880, 90, -1)
- )
- state_dict[f"{name}.down_proj_scales"] = (
- module.down_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
- module.down_proj_precision_config.weight_scale.storage.data
- ).transpose(-1, -2)
- )
- metadata = {}
- return state_dict, metadata
- def is_serializable(self, safe_serialization=None):
- return True
- @property
- def is_trainable(self) -> bool:
- logger.warning_once(
- "MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()"
- )
- return False
|