# quantizer_finegrained_fp8.py (9.3 KB)
# Fine-grained (block-wise) FP8 quantizer for transformers models.
from typing import TYPE_CHECKING, Optional

from ..utils import is_accelerate_available, is_torch_available, is_torch_xpu_available, logging
from .base import HfQuantizer
from .quantizers_utils import get_module_from_name


# Import torch only when it is installed, so that importing this module does not itself fail
# in an environment without torch.
if is_torch_available():
    import torch

# Only needed for the string ("PreTrainedModel") type annotations in this module; not imported
# at runtime (presumably to avoid a runtime import of modeling_utils — confirm).
if TYPE_CHECKING:
    from ..modeling_utils import PreTrainedModel

logger = logging.get_logger(__name__)
  10. class FineGrainedFP8HfQuantizer(HfQuantizer):
  11. """
  12. FP8 quantization implementation supporting both standard and MoE models.
  13. Supports both e4m3fn formats based on platform.
  14. """
  15. requires_parameters_quantization = True
  16. requires_calibration = False
  17. required_packages = ["accelerate"]
  18. def __init__(self, quantization_config, **kwargs):
  19. super().__init__(quantization_config, **kwargs)
  20. self.quantization_config = quantization_config
  21. def validate_environment(self, *args, **kwargs):
  22. if not is_torch_available():
  23. raise ImportError(
  24. "Using fp8 quantization requires torch >= 2.1.0"
  25. "Please install the latest version of torch ( pip install --upgrade torch )"
  26. )
  27. if not is_accelerate_available():
  28. raise ImportError("Loading an FP8 quantized model requires accelerate (`pip install accelerate`)")
  29. if kwargs.get("from_tf", False) or kwargs.get("from_flax", False):
  30. raise ValueError(
  31. "Converting into FP8 weights from tf/flax weights is currently not supported, "
  32. "please make sure the weights are in PyTorch format."
  33. )
  34. if not (torch.cuda.is_available() or is_torch_xpu_available()):
  35. raise RuntimeError("No GPU or XPU found. A GPU or XPU is needed for FP8 quantization.")
  36. if torch.cuda.is_available():
  37. compute_capability = torch.cuda.get_device_capability()
  38. major, minor = compute_capability
  39. if (major < 8) or (major == 8 and minor < 9):
  40. raise ValueError(
  41. "FP8 quantized models is only supported on GPUs with compute capability >= 8.9 (e.g 4090/H100)"
  42. f", actual = `{major}.{minor}`"
  43. )
  44. device_map = kwargs.get("device_map")
  45. if device_map is None:
  46. logger.warning_once(
  47. "You have loaded an FP8 model on CPU and have a CUDA or XPU device available, make sure to set "
  48. "your model on a GPU or XPU device in order to run your model. To remove this warning, "
  49. "pass device_map = 'cuda' or 'xpu'. "
  50. )
  51. elif device_map is not None:
  52. if (
  53. not self.pre_quantized
  54. and isinstance(device_map, dict)
  55. and ("cpu" in device_map.values() or "disk" in device_map.values())
  56. ):
  57. raise ValueError(
  58. "You are attempting to load an FP8 model with a device_map that contains a cpu/disk device."
  59. "This is not supported when the model is quantized on the fly. "
  60. "Please use a quantized checkpoint or remove the cpu/disk device from the device_map."
  61. )
  62. def update_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
  63. if dtype is None:
  64. logger.info("Setting dtype to torch.float32 as no dtype was specified in from_pretrained")
  65. dtype = torch.float32
  66. return dtype
  67. def create_quantized_param(
  68. self,
  69. model: "PreTrainedModel",
  70. param_value: "torch.Tensor",
  71. param_name: str,
  72. target_device: "torch.device",
  73. **kwargs,
  74. ):
  75. from ..integrations.finegrained_fp8 import FP8Linear
  76. from ..modeling_utils import _load_parameter_into_model
  77. # Sanity checks
  78. module, tensor_name = get_module_from_name(model, param_name)
  79. if isinstance(module, FP8Linear):
  80. if self.pre_quantized or tensor_name == "bias":
  81. if tensor_name == "weight" and param_value.dtype != torch.float8_e4m3fn:
  82. raise ValueError("Expect quantized weights but got an unquantized weight")
  83. else:
  84. if tensor_name == "weight_scale_inv":
  85. raise ValueError("Expect unquantized weights but got a quantized weight_scale")
  86. param_value = param_value.to(target_device)
  87. # Get FP8 min/max values
  88. fp8_min = torch.finfo(torch.float8_e4m3fn).min
  89. fp8_max = torch.finfo(torch.float8_e4m3fn).max
  90. block_size_m, block_size_n = self.quantization_config.weight_block_size
  91. rows, cols = param_value.shape[-2:]
  92. if rows % block_size_m != 0 or cols % block_size_n != 0:
  93. raise ValueError(
  94. f"Matrix dimensions ({rows}, {cols}) must be divisible by block sizes ({block_size_m}, {block_size_n})"
  95. )
  96. param_value_orig_shape = param_value.shape
  97. param_value = param_value.reshape(
  98. -1, rows // block_size_m, block_size_m, cols // block_size_n, block_size_n
  99. ).permute(0, 1, 3, 2, 4)
  100. # Calculate scaling factor for each block
  101. max_abs = torch.amax(torch.abs(param_value), dim=(-1, -2))
  102. scale = fp8_max / max_abs
  103. scale_orig_shape = scale.shape
  104. scale = scale.unsqueeze(-1).unsqueeze(-1)
  105. # Quantize the weights
  106. quantized_param = torch.clamp(param_value * scale, min=fp8_min, max=fp8_max).to(torch.float8_e4m3fn)
  107. quantized_param = quantized_param.permute(0, 1, 3, 2, 4)
  108. # Reshape back to matrix shape
  109. quantized_param = quantized_param.reshape(param_value_orig_shape)
  110. # Reshape scale to match the number of blocks
  111. scale = scale.reshape(scale_orig_shape).squeeze().reciprocal()
  112. # Load into the model
  113. _load_parameter_into_model(model, param_name, quantized_param)
  114. _load_parameter_into_model(model, param_name.rsplit(".", 1)[0] + ".weight_scale_inv", scale)
  115. def param_needs_quantization(self, model: "PreTrainedModel", param_name: str, **kwargs) -> bool:
  116. from ..integrations.finegrained_fp8 import FP8Linear
  117. module, tensor_name = get_module_from_name(model, param_name)
  118. if isinstance(module, FP8Linear):
  119. if self.pre_quantized or tensor_name == "bias":
  120. return False
  121. else:
  122. return True
  123. return False
  124. def _process_model_before_weight_loading(
  125. self,
  126. model: "PreTrainedModel",
  127. keep_in_fp32_modules: Optional[list[str]] = None,
  128. **kwargs,
  129. ):
  130. from ..integrations.finegrained_fp8 import replace_with_fp8_linear
  131. self.modules_to_not_convert = self.get_modules_to_not_convert(
  132. model, self.quantization_config.modules_to_not_convert, keep_in_fp32_modules
  133. )
  134. model = replace_with_fp8_linear(
  135. model,
  136. modules_to_not_convert=self.modules_to_not_convert,
  137. quantization_config=self.quantization_config,
  138. )
  139. model.config.quantization_config = self.quantization_config
  140. def _process_model_after_weight_loading(self, model: "PreTrainedModel", **kwargs):
  141. return model
  142. def update_missing_keys(self, model, missing_keys: list[str], prefix: str) -> list[str]:
  143. from ..integrations import FP8Linear
  144. not_missing_keys = []
  145. for name, module in model.named_modules():
  146. if isinstance(module, FP8Linear):
  147. for missing in missing_keys:
  148. if (
  149. (name in missing or name in f"{prefix}.{missing}")
  150. and not missing.endswith(".weight")
  151. and not missing.endswith(".bias")
  152. ):
  153. not_missing_keys.append(missing)
  154. return [k for k in missing_keys if k not in not_missing_keys]
  155. def update_tp_plan(self, config):
  156. if "Qwen3" in config.__class__.__name__:
  157. text_plan = {
  158. "layers.*.self_attn.q_proj.weight": "local_colwise",
  159. "layers.*.self_attn.q_proj.weight_scale_inv": "local_colwise",
  160. "layers.*.self_attn.k_proj.weight": "local_colwise",
  161. "layers.*.self_attn.k_proj.weight_scale_inv": "local_colwise",
  162. "layers.*.self_attn.v_proj.weight": "local_colwise",
  163. "layers.*.self_attn.v_proj.weight_scale_inv": "local_colwise",
  164. "layers.*.self_attn.o_proj.weight": "local_rowwise",
  165. "layers.*.self_attn.o_proj.weight_scale_inv": "local_rowwise",
  166. "layers.*.self_attn": "gather",
  167. "layers.*.mlp.gate_proj.weight": "local_colwise",
  168. "layers.*.mlp.gate_proj.weight_scale_inv": "local_colwise",
  169. "layers.*.mlp.up_proj.weight": "local_colwise",
  170. "layers.*.mlp.up_proj.weight_scale_inv": "local_colwise",
  171. "layers.*.mlp.down_proj.weight": "local_rowwise",
  172. "layers.*.mlp.down_proj.weight_scale_inv": "local_rowwise",
  173. "layers.*.mlp": "gather",
  174. }
  175. config.base_model_tp_plan = text_plan
  176. return config
  177. def is_serializable(self, safe_serialization=None):
  178. return True
  179. @property
  180. def is_trainable(self) -> bool:
  181. return False
  182. def get_accelerator_warm_up_factor(self):
  183. # Pre-processing is done cleanly, so we can allocate everything here
  184. return 2