finegrained_fp8.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from typing import Optional
  16. from ..utils import is_accelerate_available, is_torch_accelerator_available, is_torch_available, logging
  17. if is_torch_available():
  18. import torch
  19. import torch.nn as nn
  20. import triton
  21. import triton.language as tl
  22. from torch.nn import functional as F
  23. if is_accelerate_available():
  24. from accelerate import init_empty_weights
  25. logger = logging.get_logger(__name__)
  # Adapted from https://huggingface.co/deepseek-ai/DeepSeek-V3/blob/main/inference/kernel.py
  @triton.jit
  def act_quant_kernel(x_ptr, y_ptr, s_ptr, BLOCK_SIZE: tl.constexpr):
      """Quantize one group of ``BLOCK_SIZE`` consecutive elements with a single scale.

      The program with id ``pid`` reads elements ``[pid * BLOCK_SIZE, (pid + 1) * BLOCK_SIZE)``
      from ``x_ptr``, derives one scale from the group's absolute maximum, writes the
      scaled values to ``y_ptr`` (cast to that pointer's element type) and writes the
      scale to ``s_ptr[pid]``.

      Args:
          x_ptr: Pointer to the values to quantize.
          y_ptr: Pointer receiving the scaled values; its element type selects the cast.
          s_ptr: Pointer receiving one float scale per program.
          BLOCK_SIZE: Compile-time number of elements handled per program.
      """
      pid = tl.program_id(axis=0)
      offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
      # The load/store are unmasked: the launch must cover a whole number of groups.
      x = tl.load(x_ptr + offs).to(tl.float32)
      # Floor the absolute maximum so an all-zero (or underflowing) group yields a
      # tiny positive scale instead of 0, which would turn `x / s` into NaN/inf.
      amax = tl.maximum(tl.max(tl.abs(x)), 1e-10)
      # 448.0 is the largest finite value of float8_e4m3fn.
      s = amax / 448.0
      y = x / s
      y = y.to(y_ptr.dtype.element_ty)
      tl.store(y_ptr + offs, y)
      tl.store(s_ptr + pid, s)
  def act_quant(x: torch.Tensor, block_size: int = 128) -> tuple[torch.Tensor, torch.Tensor]:
      """Quantize ``x`` to ``float8_e4m3fn`` in groups of ``block_size`` along the last dim.

      Args:
          x: Contiguous tensor whose last dimension is a multiple of ``block_size``.
          block_size: Number of consecutive elements that share one scale.

      Returns:
          A pair ``(y, s)``: ``y`` has the shape of ``x`` with dtype ``float8_e4m3fn``;
          ``s`` is float32 with shape ``(*x.shape[:-1], x.shape[-1] // block_size)``.
      """
      assert x.is_contiguous()
      assert x.shape[-1] % block_size == 0

      *leading_dims, last_dim = x.shape
      quantized = torch.empty_like(x, dtype=torch.float8_e4m3fn)
      scales = x.new_empty((*leading_dims, last_dim // block_size), dtype=torch.float32)

      # One program per group; BLOCK_SIZE is passed as `block_size`, so the grid
      # can be computed up front.
      launch_grid = (triton.cdiv(x.numel(), block_size),)
      act_quant_kernel[launch_grid](x, quantized, scales, BLOCK_SIZE=block_size)
      return quantized, scales
  46. # Adapted from https://github.com/sgl-project/sglang/blob/main/python/sglang/srt/layers/quantization/fp8_kernel.py
  @triton.jit
  def _w8a8_block_fp8_matmul(
      # Pointers to inputs and output
      A,
      B,
      C,
      As,
      Bs,
      # Shape for matmul
      M,
      N,
      K,
      # Block size for block-wise quantization
      group_n,
      group_k,
      # Stride for inputs and output
      stride_am,
      stride_ak,
      stride_bk,
      stride_bn,
      stride_cm,
      stride_cn,
      stride_As_m,
      stride_As_k,
      stride_Bs_k,
      stride_Bs_n,
      # Meta-parameters
      BLOCK_SIZE_M: tl.constexpr,
      BLOCK_SIZE_N: tl.constexpr,
      BLOCK_SIZE_K: tl.constexpr,
      GROUP_SIZE_M: tl.constexpr,
  ):
      """Block-wise scaled matmul kernel: accumulate ``A @ B`` tile by tile into ``C``.

      Each program computes one ``BLOCK_SIZE_M x BLOCK_SIZE_N`` tile of ``C``. For every
      K tile the partial dot product is multiplied by the matching scale of ``A``
      (one value per row, read from ``As``) and of ``B`` (one value per column,
      read from ``Bs``) before being added to a float32 accumulator.
      """
      # Map the linear program id to a (pid_m, pid_n) tile, walking GROUP_SIZE_M
      # row-tiles at a time.
      pid = tl.program_id(axis=0)
      tiles_m = tl.cdiv(M, BLOCK_SIZE_M)
      tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
      tiles_per_group = GROUP_SIZE_M * tiles_n
      group_first_m = (pid // tiles_per_group) * GROUP_SIZE_M
      rows_in_group = min(tiles_m - group_first_m, GROUP_SIZE_M)
      pid_m = group_first_m + (pid % rows_in_group)
      pid_n = (pid % tiles_per_group) // rows_in_group

      # Unwrapped indices are used for the masked store; the wrapped (modulo)
      # versions keep the loads inside the M x K / K x N extents.
      tile_rows = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
      tile_cols = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
      rows_wrapped = tile_rows % M
      cols_wrapped = tile_cols % N
      k_range = tl.arange(0, BLOCK_SIZE_K)

      a_tile_ptrs = A + (rows_wrapped[:, None] * stride_am + k_range[None, :] * stride_ak)
      b_tile_ptrs = B + (k_range[:, None] * stride_bk + cols_wrapped[None, :] * stride_bn)
      a_scale_ptrs = As + rows_wrapped * stride_As_m
      b_scale_ptrs = Bs + (cols_wrapped // group_n) * stride_Bs_n

      acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
      for k_tile in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
          k_begin = k_tile * BLOCK_SIZE_K
          k_left = K - k_begin
          # Elements past the end of K are read as zero.
          a_tile = tl.load(a_tile_ptrs, mask=k_range[None, :] < k_left, other=0.0)
          b_tile = tl.load(b_tile_ptrs, mask=k_range[:, None] < k_left, other=0.0)

          # Index of the scale group this K tile starts in.
          scale_idx = k_begin // group_k
          a_scale = tl.load(a_scale_ptrs + scale_idx * stride_As_k)
          b_scale = tl.load(b_scale_ptrs + scale_idx * stride_Bs_k)

          acc += tl.dot(a_tile, b_tile) * a_scale[:, None] * b_scale[None, :]
          a_tile_ptrs += BLOCK_SIZE_K * stride_ak
          b_tile_ptrs += BLOCK_SIZE_K * stride_bk

      # Cast the accumulator according to the element type of C.
      if C.dtype.element_ty == tl.bfloat16:
          result = acc.to(tl.bfloat16)
      elif C.dtype.element_ty == tl.float16:
          result = acc.to(tl.float16)
      else:
          result = acc.to(tl.float32)

      out_ptrs = C + stride_cm * tile_rows[:, None] + stride_cn * tile_cols[None, :]
      in_bounds = (tile_rows[:, None] < M) & (tile_cols[None, :] < N)
      tl.store(out_ptrs, result, mask=in_bounds)
  def w8a8_block_fp8_matmul_triton(
      A: torch.Tensor,
      B: torch.Tensor,
      As: torch.Tensor,
      Bs: torch.Tensor,
      block_size: list[int],
      output_dtype: torch.dtype = torch.float32,
  ) -> torch.Tensor:
      """Perform a matrix multiplication with block-wise quantization.

      Computes ``A @ B.T`` where ``A`` and ``B`` are quantized tensors with scales
      ``As`` and ``Bs``, and returns the result in ``output_dtype``.

      Args:
          A: The input tensor, e.g., activation. Shape ``[..., K]``, contiguous.
          B: The input tensor, e.g., weight. Shape ``[N, K]``, contiguous.
          As: The per-token-group quantization scale for `A`, shape
              ``[..., ceil(K / block_k)]`` with the same leading dims as ``A``.
          Bs: The per-block quantization scale for `B`, shape
              ``[ceil(N / block_n), ceil(K / block_k)]``.
          block_size: The block size for per-block quantization. It should
              be 2-dim, e.g., [128, 128], interpreted as ``[block_n, block_k]``.
          output_dtype: The dtype of the returned tensor.

      Returns:
          torch.Tensor: The result of matmul, shape ``A.shape[:-1] + (N,)``.
      """
      assert len(block_size) == 2
      block_n, block_k = block_size[0], block_size[1]

      assert A.shape[-1] == B.shape[-1]
      assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
      assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
      M = A.numel() // A.shape[-1]

      assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
      N, K = B.shape
      assert triton.cdiv(N, block_n) == Bs.shape[0]
      assert triton.cdiv(K, block_k) == Bs.shape[1]

      C_shape = A.shape[:-1] + (N,)
      C = A.new_empty(C_shape, dtype=output_dtype)
      if M == 0 or N == 0:
          # Nothing to compute; skip the kernel launch entirely.
          return C

      # The kernel addresses rows through a single flattened row index, so hand it
      # explicit 2-D views. This also makes 1-D activations work (stride(-2) does
      # not exist for them) and keeps N-D scales correctly addressed.
      A_2d = A.view(M, K)
      As_2d = As.reshape(M, As.shape[-1])
      C_2d = C.view(M, N)

      BLOCK_SIZE_M = 128
      if M < BLOCK_SIZE_M:
          BLOCK_SIZE_M = triton.next_power_of_2(M)
          BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
      # One K tile per scale group, so every tile maps to exactly one scale column.
      BLOCK_SIZE_K = block_k
      BLOCK_SIZE_N = block_n

      def grid(META):
          return (triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),)

      _w8a8_block_fp8_matmul[grid](
          A_2d,
          B,
          C_2d,
          As_2d,
          Bs,
          M,
          N,
          K,
          block_n,
          block_k,
          A_2d.stride(0),
          A_2d.stride(1),
          # B is stored as [N, K]; the kernel consumes it as K x N, hence swapped strides.
          B.stride(1),
          B.stride(0),
          C_2d.stride(0),
          C_2d.stride(1),
          As_2d.stride(0),
          As_2d.stride(1),
          Bs.stride(1),
          Bs.stride(0),
          BLOCK_SIZE_M=BLOCK_SIZE_M,
          BLOCK_SIZE_N=BLOCK_SIZE_N,
          BLOCK_SIZE_K=BLOCK_SIZE_K,
          GROUP_SIZE_M=8,
      )

      return C
  193. # Python version of the above triton function, it's much slower than the triton version, for testing
  @torch.compile
  def w8a8_block_fp8_matmul_compile(
      input_q: torch.Tensor,  # [batch, seq_len, hidden_dim]
      weight_q: torch.Tensor,  # [out_features, hidden_dim]
      input_scale: torch.Tensor,  # [batch * seq_len, num_input_groups]
      weight_scale: torch.Tensor,  # [num_weight_blocks_m, num_weight_blocks_n]
      block_size: Optional[tuple[int, int]] = None,  # (M=128, N=128) for weights for example
      output_dtype: torch.dtype = torch.float32,
  ) -> torch.Tensor:
      """
      Performs blocked matrix multiplication with FP8 quantized matrices.

      Args:
          input_q: Quantized input tensor, ``[batch, seq_len, hidden_dim]`` or
              ``[seq_len, hidden_dim]`` (treated as batch size 1).
          weight_q: Quantized weight tensor, ``[out_features, hidden_dim]``.
          input_scale: Scaling factors for input blocks, one row per token. Scales
              with more than two dims (e.g. ``[batch, seq_len, num_groups]``) are
              flattened to ``[batch * seq_len, num_groups]``.
          weight_scale: Scaling factors for weight blocks, indexed as
              ``weight_scale[out_block, in_block]``.
          block_size: Tuple of (M, N) for weight block dimensions. Required.
          output_dtype: Desired output dtype

      Returns:
          Tensor of shape ``[batch, seq_len, out_features]`` in ``output_dtype``
          (a 2-D input therefore yields a leading batch dim of 1).

      Raises:
          ValueError: If ``block_size`` is not provided.
      """
      if block_size is None:
          raise ValueError("`block_size` must be provided as a (block_out, block_in) pair, e.g. (128, 128).")

      batch_size, seq_len, hidden_dim = input_q.shape if input_q.ndim == 3 else (1, input_q.shape[0], input_q.shape[1])
      out_features = weight_q.shape[0]

      # Reshape input for batched matmul
      input_reshaped = input_q.view(-1, hidden_dim)  # [batch*seq_len, hidden_dim]
      if input_scale.ndim > 2:
          # Keep the group axis last so that column j holds the scale of input block j.
          input_scale_reshaped = input_scale.reshape(-1, input_scale.shape[-1])
      else:
          input_scale_reshaped = input_scale.view(input_scale.shape[0], -1)

      block_m, block_n = block_size[0], block_size[1]
      # Calculate number of blocks
      num_weight_blocks_m = out_features // block_m
      num_weight_blocks_n = hidden_dim // block_n

      output = torch.zeros((batch_size * seq_len, out_features), dtype=torch.float32, device=input_q.device)
      # The input scale is applied after the matmul, so `_scaled_mm` gets a unit
      # scale for its first operand; build it once instead of once per block.
      unit_scale = torch.tensor(1, dtype=torch.float32, device=input_q.device)

      for i in range(num_weight_blocks_m):
          m_start = i * block_m
          m_end = m_start + block_m

          for j in range(num_weight_blocks_n):
              n_start = j * block_n
              n_end = n_start + block_n

              # Extract current blocks
              input_block = input_reshaped[:, n_start:n_end]
              weight_block = weight_q[m_start:m_end, n_start:n_end]

              # Get corresponding scales
              curr_input_scale = input_scale_reshaped[:, j : j + 1]  # [batch*seq_len, 1]
              curr_weight_scale = weight_scale[i, j]  # scalar

              block_result = (
                  torch._scaled_mm(
                      input_block,
                      weight_block.t(),
                      scale_a=unit_scale,
                      scale_b=curr_weight_scale,
                      out_dtype=output_dtype,
                  )
                  * curr_input_scale
              )

              output[:, m_start:m_end] += block_result

      output = output.view(batch_size, seq_len, out_features)

      return output.to(output_dtype)
  247. class FP8Linear(nn.Linear):
  248. dtype = torch.float8_e4m3fn
  249. def __init__(
  250. self,
  251. in_features: int,
  252. out_features: int,
  253. bias: bool = False,
  254. dtype=None,
  255. block_size: Optional[tuple[int, int]] = None,
  256. device=None,
  257. activation_scheme="dynamic",
  258. ):
  259. super().__init__(in_features, out_features)
  260. self.in_features = in_features
  261. self.out_features = out_features
  262. self.weight = torch.nn.Parameter(torch.empty(out_features, in_features, dtype=FP8Linear.dtype, device=device))
  263. if self.weight.element_size() == 1:
  264. scale_out_features = (out_features + block_size[0] - 1) // block_size[0]
  265. scale_in_features = (in_features + block_size[1] - 1) // block_size[1]
  266. self.weight_scale_inv = nn.Parameter(
  267. torch.empty(scale_out_features, scale_in_features, dtype=torch.float32, device=device)
  268. )
  269. else:
  270. self.register_parameter("weight_scale_inv", None)
  271. self.block_size = block_size
  272. self.activation_scheme = activation_scheme
  273. if bias:
  274. self.bias = nn.Parameter(torch.empty(self.out_features))
  275. else:
  276. self.register_parameter("bias", None)
  277. def forward(self, input: torch.Tensor) -> torch.Tensor:
  278. if self.weight.element_size() > 1:
  279. return F.linear(input, self.weight, self.bias)
  280. else:
  281. # Context manager used to switch among the available accelerators
  282. device_type = torch.accelerator.current_accelerator().type if is_torch_accelerator_available() else "cuda"
  283. torch_accelerator_module = getattr(torch, device_type, torch.cuda)
  284. with torch_accelerator_module.device(input.device):
  285. qinput, scale = act_quant(input, self.block_size[1])
  286. output = w8a8_block_fp8_matmul_triton(
  287. qinput,
  288. self.weight,
  289. scale,
  290. self.weight_scale_inv,
  291. self.block_size,
  292. output_dtype=input.dtype,
  293. )
  294. # Blocks the CPU until all accelerator operations on the specified device are complete. It is used to ensure that the results of the
  295. # preceding operations are ready before proceeding
  296. torch_accelerator_module.synchronize()
  297. if self.bias is not None:
  298. output = output + self.bias
  299. return output.to(dtype=input.dtype)
  300. def _replace_with_fp8_linear(
  301. model,
  302. tp_plan=None,
  303. modules_to_not_convert=None,
  304. current_key_name=None,
  305. quantization_config=None,
  306. has_been_replaced=False,
  307. ):
  308. """Replace Linear layers with FP8Linear."""
  309. if current_key_name is None:
  310. current_key_name = []
  311. for name, module in model.named_children():
  312. current_key_name.append(name)
  313. if isinstance(module, nn.Linear) and name not in (modules_to_not_convert or []):
  314. current_key_name_str = ".".join(current_key_name)
  315. if not any(key in current_key_name_str for key in (modules_to_not_convert or [])):
  316. with init_empty_weights():
  317. model._modules[name] = FP8Linear(
  318. in_features=module.in_features,
  319. out_features=module.out_features,
  320. bias=module.bias is not None,
  321. device=module.weight.device,
  322. dtype=module.weight.dtype,
  323. activation_scheme=quantization_config.activation_scheme,
  324. block_size=quantization_config.weight_block_size,
  325. )
  326. has_been_replaced = True
  327. # when changing a layer the TP PLAN for that layer should be updated. TODO
  328. if len(list(module.children())) > 0:
  329. _, has_been_replaced = _replace_with_fp8_linear(
  330. module,
  331. tp_plan,
  332. modules_to_not_convert,
  333. current_key_name,
  334. quantization_config,
  335. has_been_replaced=has_been_replaced,
  336. )
  337. current_key_name.pop(-1)
  338. return model, has_been_replaced
  def replace_with_fp8_linear(
      model,
      modules_to_not_convert=None,
      quantization_config=None,
  ):
      """Helper function to replace model layers with FP8 versions.

      Args:
          model: Module whose ``nn.Linear`` children should be converted in place.
          modules_to_not_convert: Names (or path substrings) of modules to keep
              unconverted. Defaults to ``["lm_head"]``. The list is not modified.
          quantization_config: Config providing ``modules_to_not_convert`` (merged
              into the exclusion list) as well as the settings forwarded to the
              replacement layers.

      Returns:
          The same ``model`` object, after replacement.
      """
      # Work on a copy so the caller's list is never mutated.
      excluded = ["lm_head"] if modules_to_not_convert is None else list(modules_to_not_convert)

      if quantization_config is not None and quantization_config.modules_to_not_convert is not None:
          excluded.extend(quantization_config.modules_to_not_convert)
      # Drop duplicates while keeping a deterministic order.
      excluded = list(dict.fromkeys(excluded))

      model, has_been_replaced = _replace_with_fp8_linear(
          model,
          # Plain modules may not define a tensor-parallel plan.
          tp_plan=getattr(model, "_tp_plan", None),
          modules_to_not_convert=excluded,
          quantization_config=quantization_config,
      )

      if not has_been_replaced:
          logger.warning(
              "You are loading your model using fp8 but no linear modules were found in your model."
              " Please double check your model architecture."
          )

      return model