mxfp4.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..utils import is_accelerate_available, is_torch_available, logging
  15. if is_torch_available():
  16. import torch
  17. from torch import nn
  18. if is_accelerate_available():
  19. from accelerate import init_empty_weights
  20. import re
  21. from contextlib import contextmanager
  22. logger = logging.get_logger(__name__)
  23. FP4_VALUES = [
  24. +0.0,
  25. +0.5,
  26. +1.0,
  27. +1.5,
  28. +2.0,
  29. +3.0,
  30. +4.0,
  31. +6.0,
  32. -0.0,
  33. -0.5,
  34. -1.0,
  35. -1.5,
  36. -2.0,
  37. -3.0,
  38. -4.0,
  39. -6.0,
  40. ]
  41. @contextmanager
  42. def on_device(dev):
  43. if is_torch_available():
  44. import torch
  45. if isinstance(dev, torch.Tensor):
  46. dev = dev.device
  47. elif isinstance(dev, str):
  48. dev = torch.device(dev)
  49. dev_type = getattr(dev, "type", None)
  50. if dev_type == "cuda":
  51. with torch.cuda.device(dev):
  52. yield
  53. return
  54. if dev_type == "xpu" and hasattr(torch, "xpu"):
  55. with torch.xpu.device(dev):
  56. yield
  57. return
  58. # other: CPU
  59. yield
  60. # Copied from GPT_OSS repo and vllm
  61. def quantize_to_mxfp4(w, triton_kernels_hub):
  62. downcast_to_mxfp_torch = triton_kernels_hub.numerics_details.mxfp.downcast_to_mxfp_torch
  63. w, w_scale = downcast_to_mxfp_torch(w.to(torch.bfloat16), torch.uint8, axis=1)
  64. return w, w_scale
  65. def swizzle_mxfp4(w, w_scale, triton_kernels_hub):
  66. """
  67. Changes the layout of the tensors depending on the hardware
  68. """
  69. FP4, convert_layout, wrap_torch_tensor = (
  70. triton_kernels_hub.tensor.FP4,
  71. triton_kernels_hub.tensor.convert_layout,
  72. triton_kernels_hub.tensor.wrap_torch_tensor,
  73. )
  74. layout = triton_kernels_hub.tensor_details.layout
  75. StridedLayout = triton_kernels_hub.tensor_details.layout.StridedLayout
  76. value_layout, value_layout_opts = layout.make_default_matmul_mxfp4_w_layout(mx_axis=1)
  77. w = convert_layout(wrap_torch_tensor(w, dtype=FP4), value_layout, **value_layout_opts)
  78. w_scale = convert_layout(wrap_torch_tensor(w_scale), StridedLayout)
  79. return w, w_scale
  80. # Copied from GPT_OSS repo
  81. # TODO: Add absolute link when the repo is public
  82. def convert_moe_packed_tensors(
  83. blocks,
  84. scales,
  85. *,
  86. dtype: torch.dtype = torch.bfloat16,
  87. rows_per_chunk: int = 32768 * 1024, # TODO these values are not here by mistake ;)
  88. ) -> torch.Tensor:
  89. """
  90. Convert the mxfp4 weights again, dequantizing and makes them compatible with the forward
  91. pass of GPT_OSS.
  92. """
  93. import math
  94. # Check if blocks and scales are on CPU, and move to GPU if so
  95. if not blocks.is_cuda and torch.cuda.is_available():
  96. blocks = blocks.cuda()
  97. scales = scales.cuda()
  98. scales = scales.to(torch.int32) - 127 # TODO that's because 128=2**7
  99. assert blocks.shape[:-1] == scales.shape, f"{blocks.shape[:-1]=} does not match {scales.shape=}"
  100. lut = torch.tensor(FP4_VALUES, dtype=dtype, device=blocks.device)
  101. *prefix_shape, G, B = blocks.shape
  102. rows_total = math.prod(prefix_shape) * G
  103. blocks = blocks.reshape(rows_total, B)
  104. scales = scales.reshape(rows_total, 1)
  105. out = torch.empty(rows_total, B * 2, dtype=dtype, device=blocks.device)
  106. for r0 in range(0, rows_total, rows_per_chunk):
  107. r1 = min(r0 + rows_per_chunk, rows_total)
  108. blk = blocks[r0:r1]
  109. exp = scales[r0:r1]
  110. # nibble indices -> int64
  111. idx_lo = (blk & 0x0F).to(torch.long)
  112. idx_hi = (blk >> 4).to(torch.long)
  113. sub = out[r0:r1]
  114. sub[:, 0::2] = lut[idx_lo]
  115. sub[:, 1::2] = lut[idx_hi]
  116. torch.ldexp(sub, exp, out=sub)
  117. del idx_lo, idx_hi, blk, exp, sub
  118. out = out.reshape(*prefix_shape, G, B * 2).view(*prefix_shape, G * B * 2)
  119. del blocks, scales, lut
  120. return out.transpose(1, 2).contiguous()
  121. class Mxfp4GptOssExperts(nn.Module):
  122. def __init__(self, config):
  123. super().__init__()
  124. self.num_experts = config.num_local_experts
  125. self.intermediate_size = config.intermediate_size
  126. self.hidden_size = config.hidden_size
  127. self.gate_up_proj_blocks = nn.Parameter(
  128. torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, 16, dtype=torch.uint8),
  129. requires_grad=False,
  130. )
  131. self.gate_up_proj_scales = nn.Parameter(
  132. torch.zeros(self.num_experts, 2 * self.intermediate_size, self.hidden_size // 32, dtype=torch.uint8),
  133. requires_grad=False,
  134. )
  135. self.gate_up_proj_bias = nn.Parameter(
  136. torch.zeros(self.num_experts, 2 * self.intermediate_size, dtype=torch.float32), requires_grad=False
  137. )
  138. self.down_proj_blocks = nn.Parameter(
  139. torch.zeros((self.num_experts, self.hidden_size, self.intermediate_size // 32, 16), dtype=torch.uint8),
  140. requires_grad=False,
  141. )
  142. self.down_proj_scales = nn.Parameter(
  143. torch.zeros(self.num_experts, self.hidden_size, self.intermediate_size // 32, dtype=torch.uint8),
  144. requires_grad=False,
  145. )
  146. self.down_proj_bias = nn.Parameter(
  147. torch.zeros(self.num_experts, self.hidden_size, dtype=torch.float32), requires_grad=False
  148. )
  149. self.alpha = 1.702
  150. self.limit = getattr(config, "swiglu_limit", 7.0)
  151. self.gate_up_proj_precision_config = None
  152. self.down_proj_precision_config = None
  153. self.limit = getattr(config, "swiglu_limit", 7.0)
  154. def forward(self, hidden_states: torch.Tensor, routing_data, gather_idx, scatter_idx) -> torch.Tensor:
  155. FnSpecs, FusedActivation, matmul_ogs = (
  156. triton_kernels_hub.matmul_ogs.FnSpecs,
  157. triton_kernels_hub.matmul_ogs.FusedActivation,
  158. triton_kernels_hub.matmul_ogs.matmul_ogs,
  159. )
  160. swiglu_fn = triton_kernels_hub.swiglu.swiglu_fn
  161. with on_device(hidden_states.device):
  162. act = FusedActivation(FnSpecs("swiglu", swiglu_fn, ("alpha", "limit")), (self.alpha, self.limit), 2)
  163. intermediate_cache1 = matmul_ogs(
  164. hidden_states,
  165. self.gate_up_proj,
  166. self.gate_up_proj_bias.to(torch.float32),
  167. routing_data,
  168. gather_indx=gather_idx,
  169. precision_config=self.gate_up_proj_precision_config,
  170. gammas=None,
  171. fused_activation=act,
  172. )
  173. intermediate_cache3 = matmul_ogs(
  174. intermediate_cache1,
  175. self.down_proj,
  176. self.down_proj_bias.to(torch.float32),
  177. routing_data,
  178. scatter_indx=scatter_idx,
  179. precision_config=self.down_proj_precision_config,
  180. gammas=routing_data.gate_scal,
  181. )
  182. return intermediate_cache3
  183. # Adapted from GPT_OSS repo
  184. # TODO: Add absolute link when the repo is public
  185. def routing_torch_dist(
  186. logits,
  187. n_expts_act,
  188. ):
  189. import os
  190. GatherIndx, RoutingData, ScatterIndx, compute_expt_data_torch = (
  191. triton_kernels_hub.routing.GatherIndx,
  192. triton_kernels_hub.routing.RoutingData,
  193. triton_kernels_hub.routing.ScatterIndx,
  194. triton_kernels_hub.routing.compute_expt_data_torch,
  195. )
  196. with on_device(logits.device):
  197. world_size = torch.distributed.get_world_size()
  198. rank = int(os.environ.get("LOCAL_RANK", "0"))
  199. replace_value = -1
  200. n_tokens = logits.shape[0]
  201. n_expts_tot = logits.shape[1]
  202. n_local_experts = n_expts_tot // world_size
  203. local_expert_start = rank * n_local_experts
  204. local_expert_end = (rank + 1) * n_local_experts
  205. n_gates_pad = n_tokens * n_expts_act
  206. def topk(vals, k):
  207. tk_indx = torch.argsort(-vals, dim=1, stable=True)[:, :k]
  208. tk_indx = tk_indx.long()
  209. tk_val = torch.take_along_dim(vals, tk_indx, dim=1)
  210. return tk_val, tk_indx.int()
  211. expt_scal, expt_indx = topk(logits, n_expts_act)
  212. expt_scal = torch.softmax(expt_scal, dim=-1)
  213. expt_indx, sort_indices = torch.sort(expt_indx, dim=1)
  214. expt_scal = torch.gather(expt_scal, 1, sort_indices)
  215. # Flatten and mask for local experts
  216. expt_scal = expt_scal.reshape(-1)
  217. hist = torch.histc(expt_indx, bins=n_expts_tot, max=n_expts_tot - 1)[local_expert_start:local_expert_end]
  218. expt_indx = expt_indx.view(-1).to(torch.int32)
  219. # we use a large value to replace the indices that are not in the local expert range
  220. var = 1000
  221. expt_indx = torch.where(expt_indx < local_expert_start, var, expt_indx)
  222. topk_indx = torch.argsort(expt_indx, stable=True).to(torch.int32)
  223. gate_indx = torch.argsort(topk_indx).to(torch.int32)
  224. expt_indx = torch.where(expt_indx < local_expert_end, expt_indx, replace_value)
  225. expt_indx = torch.where(local_expert_start <= expt_indx, expt_indx, replace_value)
  226. gate_indx = torch.where(expt_indx == replace_value, replace_value, gate_indx)
  227. gate_scal = expt_scal[topk_indx]
  228. topk_indx = torch.where(gate_indx[topk_indx] == replace_value, replace_value, topk_indx)
  229. # # Routing metadata for local expert computation
  230. gather_indx = GatherIndx(src_indx=topk_indx.int(), dst_indx=gate_indx.int())
  231. scatter_indx = ScatterIndx(src_indx=gate_indx.int(), dst_indx=topk_indx.int())
  232. expt_data = compute_expt_data_torch(hist, n_local_experts, n_gates_pad)
  233. hit_experts = n_expts_act
  234. return RoutingData(gate_scal, hist, n_local_experts, hit_experts, expt_data), gather_indx, scatter_indx
  235. def mlp_forward(self, hidden_states):
  236. import torch.distributed as dist
  237. if dist.is_available() and dist.is_initialized() and hasattr(self, "_is_hooked"):
  238. routing = routing_torch_dist
  239. else:
  240. routing = triton_kernels_hub.routing.routing
  241. batch_size = hidden_states.shape[0]
  242. hidden_states = hidden_states.reshape(-1, self.router.hidden_dim)
  243. router_logits = nn.functional.linear(hidden_states, self.router.weight, self.router.bias)
  244. with on_device(router_logits.device):
  245. routing_data, gather_idx, scatter_idx = routing(router_logits, self.router.top_k)
  246. routed_out = self.experts(hidden_states, routing_data, gather_idx, scatter_idx)
  247. routed_out = routed_out.reshape(batch_size, -1, self.router.hidden_dim)
  248. return routed_out, router_logits
  249. def should_convert_module(current_key_name, patterns):
  250. current_key_name_str = ".".join(current_key_name)
  251. if not any(
  252. re.match(f"{key}\\.", current_key_name_str) or re.match(f"{key}", current_key_name_str) for key in patterns
  253. ):
  254. return True
  255. return False
  256. def dequantize(module, param_name, param_value, target_device, dq_param_name, **kwargs):
  257. from ..integrations.tensor_parallel import shard_and_distribute_module
  258. model = kwargs.get("model")
  259. empty_param = kwargs.get("empty_param")
  260. casting_dtype = kwargs.get("casting_dtype")
  261. to_contiguous = kwargs.get("to_contiguous")
  262. rank = kwargs.get("rank")
  263. device_mesh = kwargs.get("device_mesh")
  264. for proj in ["gate_up_proj", "down_proj"]:
  265. if proj in param_name:
  266. if device_mesh is not None:
  267. param_value = shard_and_distribute_module(
  268. model,
  269. param_value,
  270. empty_param,
  271. dq_param_name,
  272. casting_dtype,
  273. to_contiguous,
  274. rank,
  275. device_mesh,
  276. )
  277. blocks_attr = f"{proj}_blocks"
  278. scales_attr = f"{proj}_scales"
  279. setattr(module, param_name.rsplit(".", 1)[1], param_value)
  280. if hasattr(module, blocks_attr) and hasattr(module, scales_attr):
  281. dequantized = convert_moe_packed_tensors(getattr(module, blocks_attr), getattr(module, scales_attr))
  282. if target_device == "cpu" and torch.cuda.is_available():
  283. torch.cuda.empty_cache()
  284. setattr(module, proj, torch.nn.Parameter(dequantized.to(target_device)))
  285. delattr(module, blocks_attr)
  286. delattr(module, scales_attr)
  287. def load_and_swizzle_mxfp4(module, param_name, param_value, target_device, triton_kernels_hub, **kwargs):
  288. """
  289. This transforms the weights obtained using `convert_gpt_oss.py` to load them into `Mxfp4GptOssExperts`.
  290. """
  291. PrecisionConfig, FlexCtx, InFlexData = (
  292. triton_kernels_hub.matmul_ogs.PrecisionConfig,
  293. triton_kernels_hub.matmul_ogs.FlexCtx,
  294. triton_kernels_hub.matmul_ogs.InFlexData,
  295. )
  296. from ..integrations.tensor_parallel import shard_and_distribute_module
  297. model = kwargs.get("model")
  298. empty_param = kwargs.get("empty_param")
  299. casting_dtype = kwargs.get("casting_dtype")
  300. to_contiguous = kwargs.get("to_contiguous")
  301. rank = kwargs.get("rank")
  302. device_mesh = kwargs.get("device_mesh")
  303. if "blocks" in param_name:
  304. proj = param_name.split(".")[-1].split("_blocks")[0]
  305. if "scales" in param_name:
  306. proj = param_name.split(".")[-1].split("_scales")[0]
  307. if device_mesh is not None:
  308. shard_and_distribute_module(
  309. model, param_value, empty_param, param_name, casting_dtype, to_contiguous, rank, device_mesh
  310. )
  311. else:
  312. setattr(module, param_name.rsplit(".", 1)[1], torch.nn.Parameter(param_value, requires_grad=False))
  313. blocks_attr = f"{proj}_blocks"
  314. scales_attr = f"{proj}_scales"
  315. blocks = getattr(module, blocks_attr) # at this point values were loaded from ckpt
  316. scales = getattr(module, scales_attr)
  317. # Check if both blocks and scales both not on meta device
  318. if blocks.device.type != "meta" and scales.device.type != "meta":
  319. local_experts = blocks.size(0)
  320. if proj == "gate_up_proj":
  321. blocks = blocks.reshape(local_experts, module.intermediate_size * 2, -1)
  322. else:
  323. blocks = blocks.reshape(local_experts, -1, module.intermediate_size // 2)
  324. if getattr(target_device, "type", target_device) == "cpu":
  325. target_device = "cuda"
  326. blocks = blocks.to(target_device).contiguous()
  327. scales = scales.to(target_device).contiguous()
  328. with on_device(target_device):
  329. triton_weight_tensor, weight_scale = swizzle_mxfp4(
  330. blocks.transpose(-2, -1), scales.transpose(-2, -1), triton_kernels_hub
  331. )
  332. # need to overwrite the shapes for the kernels
  333. if proj == "gate_up_proj":
  334. triton_weight_tensor.shape = torch.Size([local_experts, module.hidden_size, module.intermediate_size * 2])
  335. else:
  336. triton_weight_tensor.shape = torch.Size([local_experts, module.intermediate_size, module.hidden_size])
  337. # triton_weight_tensor is what needs to be passed in oai kernels. It stores the data, the shapes and any more objects. It is like a subtensor
  338. setattr(module, proj, triton_weight_tensor)
  339. setattr(
  340. module,
  341. f"{proj}_precision_config",
  342. PrecisionConfig(weight_scale=weight_scale, flex_ctx=FlexCtx(rhs_data=InFlexData())),
  343. )
  344. # delete blocks and scales
  345. delattr(module, scales_attr)
  346. delattr(module, blocks_attr)
  347. del blocks
  348. def _replace_with_mxfp4_linear(
  349. model,
  350. modules_to_not_convert=None,
  351. current_key_name=None,
  352. quantization_config=None,
  353. has_been_replaced=False,
  354. config=None,
  355. ):
  356. if current_key_name is None:
  357. current_key_name = []
  358. for name, module in model.named_children():
  359. current_key_name.append(name)
  360. if not should_convert_module(current_key_name, modules_to_not_convert):
  361. current_key_name.pop(-1)
  362. continue
  363. if module.__class__.__name__ == "GptOssExperts" and not quantization_config.dequantize:
  364. with init_empty_weights():
  365. model._modules[name] = Mxfp4GptOssExperts(config)
  366. has_been_replaced = True
  367. if module.__class__.__name__ == "GptOssMLP" and not quantization_config.dequantize:
  368. from types import MethodType
  369. module.forward = MethodType(mlp_forward, module)
  370. if len(list(module.children())) > 0:
  371. _, has_been_replaced = _replace_with_mxfp4_linear(
  372. module,
  373. modules_to_not_convert,
  374. current_key_name,
  375. quantization_config,
  376. has_been_replaced=has_been_replaced,
  377. config=config,
  378. )
  379. current_key_name.pop(-1)
  380. return model, has_been_replaced
  381. def replace_with_mxfp4_linear(
  382. model,
  383. modules_to_not_convert=None,
  384. current_key_name=None,
  385. quantization_config=None,
  386. config=None,
  387. ):
  388. if quantization_config.dequantize:
  389. return model
  390. else:
  391. from kernels import get_kernel
  392. global triton_kernels_hub
  393. triton_kernels_hub = get_kernel("kernels-community/triton_kernels")
  394. modules_to_not_convert = ["lm_head"] if modules_to_not_convert is None else modules_to_not_convert
  395. if quantization_config.modules_to_not_convert is not None:
  396. modules_to_not_convert.extend(quantization_config.modules_to_not_convert)
  397. modules_to_not_convert = list(set(modules_to_not_convert))
  398. model, has_been_replaced = _replace_with_mxfp4_linear(
  399. model,
  400. modules_to_not_convert,
  401. current_key_name,
  402. quantization_config,
  403. config=config,
  404. )
  405. if not has_been_replaced:
  406. logger.warning(
  407. "You are loading your model using mixed-precision FP4 quantization but no linear modules were found in your model."
  408. " Please double check your model architecture, or submit an issue on github if you think this is"
  409. " a bug."
  410. )
  411. return model