
Commit f0e6334

jiawenliu64 authored and facebook-github-bot committed

Enable MXFP8 grouped GEMM

Summary: As title

Differential Revision: D80350093

1 parent 8c442ac

12 files changed, +1492 -0 lines

fbgemm_gpu/experimental/gemm/triton_gemm/fp8_gemm.py

Lines changed: 71 additions & 0 deletions
@@ -4359,3 +4359,74 @@ def grid(meta):
        xq, x_scale, x_dequant, M, K, BLOCK_M=block_m, BLOCK_K=block_k  # pyre-ignore[6]
    )
    return x_dequant
+
+
+# This function is extracted from https://github.com/pytorch/ao/blob/v0.12.0/torchao/prototype/mx_formats/mx_tensor.py#L142
+def to_mxfp8(
+    data_hp: torch.Tensor,
+    block_size: int = 32,
+):
+    assert data_hp.dtype in (
+        torch.bfloat16,
+        torch.float,
+    ), f"{data_hp.dtype} is not supported yet"
+    assert (
+        data_hp.shape[-1] % block_size == 0
+    ), f"the last dimension of shape {data_hp.shape} must be divisible by block_size {block_size}"
+    assert data_hp.is_contiguous(), "unsupported"
+
+    orig_shape = data_hp.shape
+    data_hp = data_hp.reshape(
+        *orig_shape[:-1], orig_shape[-1] // block_size, block_size
+    )
+
+    max_abs = torch.amax(torch.abs(data_hp), -1).unsqueeze(-1)
+
+    data_hp = data_hp.to(torch.float32)
+    max_abs = max_abs.to(torch.float32)
+
+    F8E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0
+    max_pos = F8E4M3_MAX
+
+    # RCEIL
+    def _to_mx_rceil(
+        data_hp: torch.Tensor,
+        max_abs: torch.Tensor,
+        max_pos: float,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        E8M0_EXPONENT_BIAS = 127
+        descale = max_abs / max_pos
+        exponent = torch.where(
+            torch.isnan(descale),
+            0xFF,  # Handle biased exponent for nan
+            # NOTE: descale < (torch.finfo(torch.float32).smallest_normal / 2) is handled through clamping
+            (
+                torch.clamp(
+                    torch.ceil(torch.log2(descale)),
+                    min=-E8M0_EXPONENT_BIAS,
+                    max=E8M0_EXPONENT_BIAS,
+                )
+                + E8M0_EXPONENT_BIAS
+            ).to(torch.uint8),
+        )
+
+        descale_fp = torch.where(
+            exponent == 0,
+            1.0,
+            torch.exp2(E8M0_EXPONENT_BIAS - exponent.to(torch.float32)),
+        )
+
+        # scale and saturated cast the data elements to max of target dtype
+        data_lp = torch.clamp(data_hp * descale_fp, min=-1 * max_pos, max=max_pos)
+        return exponent, data_lp
+
+    scale_e8m0_biased, data_lp = _to_mx_rceil(data_hp, max_abs, max_pos)
+
+    # cast to target dtype
+    data_lp = data_lp.to(torch.float8_e4m3fn)
+    # need to reshape at the end to help inductor fuse things
+    data_lp = data_lp.reshape(orig_shape)
+
+    scale_e8m0_biased = scale_e8m0_biased.view(torch.float8_e8m0fnu)
+    scale_e8m0_biased = scale_e8m0_biased.squeeze(-1)
+    return scale_e8m0_biased, data_lp
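
For context (not part of the diff): a minimal usage sketch of to_mxfp8, assuming a PyTorch build that exposes torch.float8_e4m3fn and torch.float8_e8m0fnu. It quantizes a bfloat16 tensor along the last dimension in 32-element blocks and returns the biased E8M0 scales first, then the float8_e4m3fn payload; the decode step shows how a biased exponent maps back to a power-of-two multiplier (the exponent == 0 special case above is ignored here).

import torch

from fbgemm_gpu.experimental.gemm.triton_gemm.fp8_gemm import to_mxfp8

x = torch.randn(128, 256, dtype=torch.bfloat16)

x_scale, xq = to_mxfp8(x)  # scales first, quantized data second

print(xq.dtype, xq.shape)            # torch.float8_e4m3fn, (128, 256) -- same shape as the input
print(x_scale.dtype, x_scale.shape)  # torch.float8_e8m0fnu, (128, 8)  -- one scale per 32-element block

# E8M0 stores a biased power-of-two exponent; exp2(exponent - 127) recovers the
# per-block multiplier used to dequantize back toward the original magnitudes.
scale_f32 = torch.exp2(x_scale.view(torch.uint8).to(torch.float32) - 127.0)
x_approx = xq.to(torch.float32).reshape(128, 8, 32) * scale_f32.unsqueeze(-1)
print((x_approx.reshape(128, 256) - x.to(torch.float32)).abs().max())  # small, fp8-level error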

fbgemm_gpu/experimental/gen_ai/bench/quantize_ops.py

Lines changed: 85 additions & 0 deletions
@@ -15,6 +15,7 @@
import triton  # @manual=//triton:triton

from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import (
+    _to_blocked,
    calculate_group_max,
    mega_fp4_pack,
    mega_fp4_quantize_kernel,
@@ -33,6 +34,7 @@
    quantize_fp8_group,
    quantize_fp8_row,
    scale_fp8_row,
+    to_mxfp8,
    triton_quantize_fp8_row,
)
from fbgemm_gpu.experimental.gemm.triton_gemm.grouped_gemm import (
@@ -2868,3 +2870,86 @@ def hip(self) -> bool:
    @property
    def cuda(self) -> bool:
        return True
+
+
+@register_quantize_op
+class MXFP8StackedGroupedGemm(QuantizeOpBase):
+    """
+    MXFP8 grouped matmul with blockwise scaling and stacked inputs.
+    """
+
+    def preprocess(self, x, w):
+        m_values = [i.shape[0] for i in x]
+        m_sizes = torch.tensor(m_values).to(dtype=torch.int64, device=x[0].device)
+        wq_list = []
+        w_scale_list = []
+        for i in range(m_sizes.shape[0]):
+            w_scale, wq = to_mxfp8(w[i])
+            w_scale = _to_blocked(w_scale)
+            wq_list.append(wq)
+            w_scale_list.append(w_scale)
+        wq = torch.stack(wq_list, dim=0).contiguous()
+        w_scale = torch.stack(w_scale_list, dim=0).contiguous()
+        return x, wq, w_scale, m_sizes
+
+    def quantize(self, x, wq, w_scale, m_sizes):
+        starting_row_after_padding_list = [0]
+        xq_list = []
+        x_scale_list = []
+        for i in range(m_sizes.shape[0]):
+            scale_slice = x[i]
+            if m_sizes[i].item() != 0:
+                x_scale, xq = to_mxfp8(scale_slice)
+                x_scale = _to_blocked(x_scale)
+                xq_list.append(xq)
+                x_scale_list.append(x_scale)
+                starting_row_after_padding_list.append(
+                    starting_row_after_padding_list[i]
+                    + x_scale.numel() // (x[0].shape[1] // 32)
+                )
+            else:
+                starting_row_after_padding_list.append(
+                    starting_row_after_padding_list[i]
+                )
+        xq = torch.cat(xq_list, dim=0).contiguous()
+        x_scale = torch.cat(x_scale_list, dim=0).contiguous()
+        x_scale = x_scale.reshape(-1, x[0].shape[-1] // 32)
+        xq = xq.view(-1, xq.shape[-1])
+        return (
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            m_sizes,
+            torch.tensor(starting_row_after_padding_list, device=xq.device),
+        )
+
+    def compute(self, xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding):
+        return torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
+            xq,
+            wq,
+            x_scale,
+            w_scale,
+            m_sizes,
+            starting_row_after_padding=starting_row_after_padding,
+        )
+
+    def quantize_and_compute(self, x, w):
+        xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding = self.quantize(
+            x, w
+        )
+        return self.compute(
+            xq, wq, x_scale, w_scale, m_sizes, starting_row_after_padding
+        )
+
+    @property
+    def name(self) -> str:
+        return "cutlass_mx8mx8bf16_grouped_stacked"
+
+    @property
+    def hip(self) -> bool:
+        return False
+
+    @property
+    def cuda(self) -> bool:
+        return True
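
For context (not part of the diff): a rough sketch of driving the new benchmark op directly, assuming a CUDA build of fbgemm_gpu gen_ai that registers torch.ops.fbgemm.mx8mx8bf16_grouped_stacked, hardware that supports the underlying MXFP8 kernels, and that the op class can be instantiated with no arguments as in the bench harness; the group shapes are illustrative. Weights are quantized and stacked once in preprocess, activations are quantized per nonempty group in quantize (starting_row_after_padding records each group's row offset into the padded, blocked scale tensor), and compute dispatches the stacked grouped GEMM.

import torch

# Illustrative shapes: G groups sharing N and K (K divisible by the MX block size of 32).
G, N, K = 4, 4096, 2048
Ms = [128, 256, 64, 512]

x = [torch.randn(m, K, dtype=torch.bfloat16, device="cuda") for m in Ms]
w = [torch.randn(N, K, dtype=torch.bfloat16, device="cuda") for _ in range(G)]

op = MXFP8StackedGroupedGemm()
x, wq, w_scale, m_sizes = op.preprocess(x, w)  # quantize/stack weights and their blocked scales
args = op.quantize(x, wq, w_scale, m_sizes)    # quantize stacked activations group by group
out = op.compute(*args)                        # bf16 result of the stacked grouped GEMM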

0 commit comments