quantizer_mxfp4.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424
  1. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Optional
  15. from .base import HfQuantizer
  16. if TYPE_CHECKING:
  17. from ..modeling_utils import PreTrainedModel
  18. from ..utils import (
  19. is_accelerate_available,
  20. is_kernels_available,
  21. is_torch_available,
  22. is_triton_available,
  23. logging,
  24. )
  25. from .quantizers_utils import get_module_from_name
  26. if is_torch_available():
  27. import torch
  28. logger = logging.get_logger(__name__)
  29. triton_kernels_hub = None
  30. class Mxfp4HfQuantizer(HfQuantizer):
  31. """
  32. FP4 quantization using fbgemm kernels
  33. """
  34. requires_parameters_quantization = True
  35. requires_calibration = False
  36. required_packages = ["accelerate"]
  37. def __init__(self, quantization_config, **kwargs):
  38. super().__init__(quantization_config, **kwargs)
  39. self.quantization_config = quantization_config
  40. self.triton_kernels_hub = None
  41. def _lazy_import_kernels(self):
  42. """Lazy import and initialize kernels only when needed"""
  43. if self.triton_kernels_hub is None:
  44. try:
  45. from kernels import get_kernel
  46. self.triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
  47. except ImportError:
  48. raise ImportError("kernels package is required for MXFP4 quantization")
  49. return self.triton_kernels_hub
  50. def validate_environment(self, *args, **kwargs):
  51. if not is_torch_available():
  52. raise ImportError(
  53. "Using mxfp4 quantization requires torch"
  54. "Please install the latest version of torch ( pip install --upgrade torch )"
  55. )
  56. if self.quantization_config.dequantize:
  57. return
  58. if not (torch.cuda.is_available() or torch.xpu.is_available()):
  59. if self.pre_quantized:
  60. logger.warning_once(
  61. "Using MXFP4 quantized models requires a GPU, we will default to dequantizing the model to bf16"
  62. )
  63. self.quantization_config.dequantize = True
  64. return
  65. else:
  66. raise RuntimeError("Quantizing a model using MXFP4 requires a GPU")
  67. if not is_accelerate_available():
  68. raise ImportError("Using mxfp4 requires Accelerate: `pip install accelerate`")
  69. if torch.xpu.is_available():
  70. gpu_is_supported = True
  71. kernels_available = is_triton_available("3.5.0") and is_kernels_available()
  72. else:
  73. compute_capability = torch.cuda.get_device_capability()
  74. gpu_is_supported = compute_capability >= (7, 5)
  75. kernels_available = is_triton_available("3.4.0") and is_kernels_available()
  76. if self.pre_quantized:
  77. # On unsupported GPUs or without kernels, we will dequantize the model to bf16
  78. if not gpu_is_supported:
  79. logger.warning_once(
  80. "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series) "
  81. "We will default to dequantizing the model to bf16."
  82. )
  83. self.quantization_config.dequantize = True
  84. return
  85. if not kernels_available:
  86. logger.warning_once(
  87. "MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0, we will default to dequantizing the model to bf16"
  88. )
  89. self.quantization_config.dequantize = True
  90. return
  91. elif not gpu_is_supported:
  92. # we can't quantize the model in this case so we raise an error
  93. raise ValueError(
  94. "MXFP4 quantization is only supported on GPUs with compute capability >= 7.5 (e.g T4, A100, L4, H100, or B200) or XPUs (e.g Intel® Data Center GPU Max Series) "
  95. )
  96. elif not kernels_available:
  97. # we can't quantize the model in this case so we raise an error
  98. raise ValueError(
  99. "MXFP4 quantization requires Triton and kernels installed: CUDA requires Triton >= 3.4.0, XPU requires Triton >= 3.5.0"
  100. )
  101. if not self.pre_quantized:
  102. self._lazy_import_kernels()
  103. device_map = kwargs.get("device_map")
  104. if device_map is None:
  105. logger.warning_once(
  106. "You have loaded an FP4 model on CPU and have a CUDA/XPU device available, make sure to set "
  107. "your model on a GPU/XPU device in order to run your model. To remove this warning, pass device_map = 'cuda' or device_map = 'xpu'. "
  108. )
  109. elif device_map is not None:
  110. if (
  111. not self.pre_quantized
  112. and isinstance(device_map, dict)
  113. and ("cpu" in device_map.values() or "disk" in device_map.values())
  114. ):
  115. raise ValueError(
  116. "You are attempting to load an FP4 model with a device_map that contains a CPU or disk device."
  117. "This is not supported when the model is quantized on the fly. "
  118. "Please use a quantized checkpoint or remove the CPU or disk device from the device_map."
  119. )
  120. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  121. if dtype is None:
  122. dtype = torch.bfloat16
  123. logger.info(
  124. "Overriding dtype=%s with `dtype=torch.bfloat16` due to "
  125. "requirements of `fbgemm-gpu` to enable model loading in fp4. "
  126. "Pass your own dtype to specify the dtype of the remaining non-linear layers or pass"
  127. " dtype=torch.bfloat16 to remove this warning.",
  128. dtype,
  129. )
  130. return dtype
  131. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  132. from ..integrations import Mxfp4GptOssExperts
  133. from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
  134. # if we are dequantizing, the model doesn't have scales, and blocks only params like gate_up_proj and down_proj so we need to handle this case differently
  135. if self.quantization_config.dequantize and ("blocks" in param_name or "scales" in param_name):
  136. module, tensor_name = get_module_from_name(model, param_name[: -len("_blocks")])
  137. else:
  138. module, tensor_name = get_module_from_name(model, param_name)
  139. if isinstance(module, Mxfp4GptOssExperts) or (
  140. isinstance(module, GptOssExperts) and self.quantization_config.dequantize
  141. ):
  142. if tensor_name in ["down_proj_bias", "gate_up_proj_bias"]:
  143. return False
  144. return True
  145. return False
  146. def create_quantized_param(
  147. self,
  148. model: "PreTrainedModel",
  149. param_value: "torch.Tensor",
  150. param_name: str,
  151. target_device: "torch.device",
  152. **kwargs,
  153. ):
  154. from ..integrations import (
  155. Mxfp4GptOssExperts,
  156. dequantize,
  157. load_and_swizzle_mxfp4,
  158. quantize_to_mxfp4,
  159. swizzle_mxfp4,
  160. )
  161. from ..models.gpt_oss.modeling_gpt_oss import GptOssExperts
  162. if not self.pre_quantized:
  163. triton_kernels_hub = self._lazy_import_kernels()
  164. module, _ = get_module_from_name(model, param_name)
  165. with torch.device(target_device):
  166. if isinstance(module, Mxfp4GptOssExperts):
  167. triton_weight_tensor, weight_scale = quantize_to_mxfp4(param_value, triton_kernels_hub)
  168. PrecisionConfig, FlexCtx, InFlexData = (
  169. triton_kernels_hub.matmul_ogs.PrecisionConfig,
  170. triton_kernels_hub.matmul_ogs.FlexCtx,
  171. triton_kernels_hub.matmul_ogs.InFlexData,
  172. )
  173. triton_weight_tensor, weight_scale = swizzle_mxfp4(
  174. triton_weight_tensor, weight_scale, triton_kernels_hub
  175. )
  176. proj = "gate_up_proj" if "gate_up_proj" in param_name else "down_proj"
  177. setattr(module, proj, triton_weight_tensor)
  178. setattr(
  179. module,
  180. f"{proj}_precision_config",
  181. PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
  182. )
  183. delattr(module, f"{proj}_blocks")
  184. delattr(module, f"{proj}_scales")
  185. # The params going here are either gate_up_proj_blocks, or down_proj_blocks, or gate_up_proj_scales, or down_proj_scales
  186. else:
  187. # This is when loading a quantized model (blocks and scales exist)
  188. empty_param = kwargs.get("empty_param")
  189. casting_dtype = kwargs.get("casting_dtype")
  190. to_contiguous = kwargs.get("to_contiguous")
  191. rank = kwargs.get("rank")
  192. device_mesh = kwargs.get("device_mesh")
  193. if ("blocks" in param_name or "scales" in param_name) and self.quantization_config.dequantize:
  194. # blocks and scales have the same length that's why this works for both
  195. module, _ = get_module_from_name(model, param_name[: -len("_blocks")])
  196. else:
  197. module, _ = get_module_from_name(model, param_name)
  198. shard_kwargs = {
  199. "empty_param": empty_param,
  200. "casting_dtype": casting_dtype,
  201. "to_contiguous": to_contiguous,
  202. "rank": rank,
  203. "device_mesh": device_mesh,
  204. "model": model,
  205. }
  206. if isinstance(module, Mxfp4GptOssExperts) or (
  207. isinstance(module, GptOssExperts) and self.quantization_config.dequantize
  208. ):
  209. if self.quantization_config.dequantize:
  210. # dq_param_name is the name of the parameter without the blocks or scales suffix, it's used in this case since we don't switch linears
  211. # so we only have the original param name
  212. dq_param_name = param_name[: -len("_blocks")]
  213. dequantize(module, param_name, param_value, target_device, dq_param_name, **shard_kwargs)
  214. else:
  215. load_and_swizzle_mxfp4(
  216. module,
  217. param_name,
  218. param_value,
  219. target_device,
  220. self._lazy_import_kernels(),
  221. **shard_kwargs,
  222. )
  223. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  224. # we are not really dequantizing, we are just removing everything related to quantization here
  225. if self.quantization_config.dequantize:
  226. self.remove_quantization_config(model)
  227. # clean cache due to triton ops
  228. if torch.cuda.is_available():
  229. torch.cuda.empty_cache()
  230. elif torch.xpu.is_available():
  231. torch.xpu.empty_cache()
  232. def update_expected_keys(self, model: "PreTrainedModel", expected_keys: list[str], checkpoint_keys: list[str]):
  233. # Replace expected_keys for experts' gate_up_proj and down_proj with their _blocks and _scales variants
  234. new_expected_keys = []
  235. for key in expected_keys:
  236. if key.endswith(".mlp.experts.gate_up_proj"):
  237. base = key[: -len("gate_up_proj")]
  238. new_expected_keys.append(base + "gate_up_proj_blocks")
  239. new_expected_keys.append(base + "gate_up_proj_scales")
  240. elif key.endswith(".mlp.experts.down_proj"):
  241. base = key[: -len("down_proj")]
  242. new_expected_keys.append(base + "down_proj_blocks")
  243. new_expected_keys.append(base + "down_proj_scales")
  244. elif not self.pre_quantized:
  245. # in this case, we are quantizing the model so we need to update the keys as we changed the layers
  246. if key.endswith(".mlp.experts.down_proj_blocks"):
  247. base = key[: -len("down_proj_blocks")]
  248. new_expected_keys.append(base + "down_proj")
  249. elif key.endswith(".mlp.experts.gate_up_proj_blocks"):
  250. base = key[: -len("gate_up_proj_blocks")]
  251. new_expected_keys.append(base + "gate_up_proj")
  252. elif key.endswith("scales"):
  253. # we remove it the scales as the checkpoint don't contain them
  254. continue
  255. else:
  256. new_expected_keys.append(key)
  257. else:
  258. new_expected_keys.append(key)
  259. return new_expected_keys
  260. def _process_model_before_weight_loading(
  261. self,
  262. model: "PreTrainedModel",
  263. keep_in_fp32_modules: Optional[list[str]] = None,
  264. **kwargs,
  265. ):
  266. from ..integrations import replace_with_mxfp4_linear
  267. self.modules_to_not_convert = self.get_modules_to_not_convert(
  268. model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
  269. )
  270. use_kernels = kwargs.get("use_kernels", False)
  271. # if we are using kernels, we can't use the quantized model, since the forward pass is different and needs special handling
  272. if use_kernels:
  273. logger.warning_once(
  274. "You are using full precision kernels, we will dequantize the model to bf16. "
  275. "To use the quantized model with quantization kernels, please set use_kernels=False"
  276. )
  277. self.quantization_config.dequantize = True
  278. config = model.config
  279. model = replace_with_mxfp4_linear(
  280. model,
  281. modules_to_not_convert=self.modules_to_not_convert,
  282. quantization_config=self.quantization_config,
  283. config=config,
  284. )
  285. model.config.quantization_config = self.quantization_config
  286. def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
  287. from ..integrations import Mxfp4GptOssExperts
  288. not_missing_keys = []
  289. for name, module in model.named_modules():
  290. if isinstance(module, Mxfp4GptOssExperts):
  291. for missing in missing_keys:
  292. if (
  293. (name in missing or name in f"{prefix}.{missing}")
  294. and not missing.endswith(".weight")
  295. and not missing.endswith(".bias")
  296. ):
  297. not_missing_keys.append(missing)
  298. return [k for k in missing_keys if k not in not_missing_keys]
  299. def update_tp_plan(self, config):
  300. if "GptOssConfig" in config.__class__.__name__:
  301. if getattr(config, "base_model_tp_plan", None) is not None:
  302. config.base_model_tp_plan.update(
  303. {
  304. "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
  305. "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
  306. "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
  307. "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
  308. }
  309. )
  310. return config
  311. def update_ep_plan(self, config):
  312. if "GptOssConfig" in config.__class__.__name__:
  313. if getattr(config, "base_model_ep_plan", None) is not None:
  314. config.base_model_ep_plan.update(
  315. {
  316. "layers.*.mlp.experts.gate_up_proj_blocks": "grouped_gemm",
  317. "layers.*.mlp.experts.gate_up_proj_scales": "grouped_gemm",
  318. "layers.*.mlp.experts.down_proj_blocks": "grouped_gemm",
  319. "layers.*.mlp.experts.down_proj_scales": "grouped_gemm",
  320. }
  321. )
  322. return config
  323. def get_param_name(self, param_name: str) -> str:
  324. if self.quantization_config.dequantize:
  325. if "_blocks" in param_name:
  326. return param_name.replace("_blocks", "")
  327. elif "_scales" in param_name:
  328. return param_name.replace("_scales", "")
  329. elif not self.pre_quantized:
  330. if param_name.endswith("gate_up_proj"):
  331. return param_name.replace("gate_up_proj", "gate_up_proj_blocks")
  332. if param_name.endswith("down_proj"):
  333. return param_name.replace("down_proj", "down_proj_blocks")
  334. return param_name
  335. def get_state_dict_and_metadata(self, model, safe_serialization: bool = False):
  336. from ..integrations import Mxfp4GptOssExperts
  337. state_dict = model.state_dict()
  338. for name, module in model.named_modules():
  339. if (
  340. isinstance(module, Mxfp4GptOssExperts)
  341. and hasattr(module, "gate_up_proj")
  342. and hasattr(module, "down_proj")
  343. ):
  344. state_dict[f"{name}.gate_up_proj_blocks"] = (
  345. module.gate_up_proj.storage.layout.unswizzle_data(module.gate_up_proj.storage.data)
  346. .transpose(-1, -2)
  347. .reshape(32, -1, 90, 16)
  348. )
  349. state_dict[f"{name}.gate_up_proj_scales"] = (
  350. module.gate_up_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
  351. module.gate_up_proj_precision_config.weight_scale.storage.data
  352. ).transpose(-1, -2)
  353. )
  354. state_dict[f"{name}.down_proj_blocks"] = (
  355. module.down_proj.storage.layout.unswizzle_data(module.down_proj.storage.data)
  356. .transpose(-1, -2)
  357. .reshape(32, 2880, 90, -1)
  358. )
  359. state_dict[f"{name}.down_proj_scales"] = (
  360. module.down_proj_precision_config.weight_scale.storage.layout.unswizzle_data(
  361. module.down_proj_precision_config.weight_scale.storage.data
  362. ).transpose(-1, -2)
  363. )
  364. metadata = {}
  365. return state_dict, metadata
  366. def is_serializable(self, safe_serialization=None):
  367. return True
  368. @property
  369. def is_trainable(self) -> bool:
  370. logger.warning_once(
  371. "MXFP4 quantization don't support training, please consider dequantizing the model first by passing quantization_config=Mxfp4Config(dequantize=True) to .from_pretrained()"
  372. )
  373. return False