Skip to content

Commit ad54e2b

Browse files
[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm
stack-info: PR: #2848, branch: danielvegamyhre/stack/55
1 parent b663faf commit ad54e2b

File tree

6 files changed

+280
-52
lines changed

6 files changed

+280
-52
lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -230,25 +230,26 @@ def compute_reference_forward(
230230
@pytest.mark.parametrize("num_experts", (1, 8, 16))
231231
def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
232232
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
233-
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
233+
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
234234
offs = generate_jagged_offs(num_experts, M)
235-
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
235+
x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone()
236236

237237
# Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
238238
block_size = 32
239-
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
239+
x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
240240

241241
# To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
242-
w_scale, w_mx = to_mx(
243-
w_t.transpose(-2, -1).contiguous(),
242+
w_scale, w_fp8 = to_mx(
243+
w,
244244
elem_dtype=torch.float8_e4m3fn,
245245
block_size=block_size,
246246
)
247-
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
248247

249-
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
248+
ref_out = torch._grouped_mm(
249+
x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16
250+
)
250251
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
251-
x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
252+
x_fp8, x_scale, w_fp8, w_scale, offs=offs, out_dtype=torch.bfloat16
252253
)
253254

254255
sqnr = compute_error(ref_out, out)
@@ -305,18 +306,25 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):
305306

306307

307308
@skip_if_rocm("ROCm not supported")
308-
@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
309-
@pytest.mark.parametrize("num_experts", (1, 8, 16))
309+
@pytest.mark.parametrize(
310+
"M,K,N", [(1024, 5120, 8192), (2048, 5120, 8192), (16640, 5120, 8192)]
311+
)
312+
@pytest.mark.parametrize("num_experts", (2, 4, 8, 16))
310313
def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
311314
from torchao.prototype.moe_training.scaled_grouped_mm import (
312315
_MXFP8GroupedMM,
313316
)
314317

315318
block_size = 32
316319
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
317-
w_t = torch.randn(
318-
num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
320+
w = torch.randn(
321+
num_experts,
322+
N,
323+
K,
324+
dtype=torch.bfloat16,
325+
device="cuda",
319326
)
327+
w_t = w.transpose(-2, -1).requires_grad_(True)
320328
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
321329
x_ref, w_t_ref, offs_ref = (
322330
x.clone().detach().requires_grad_(True),

test/prototype/moe_training/test_training.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
129129
)
130130

131131

132+
@pytest.mark.skip(
133+
"temporarily disable until non-uniform group sizes are supported by mxfp8 grouped gemm"
134+
)
132135
@pytest.mark.parametrize(
133136
"target_fqns",
134137
[

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
88
triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
99
)
10+
from torchao.prototype.moe_training.kernels.mxfp8 import (
11+
fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
12+
)
Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import logging
2+
3+
import torch
4+
5+
from torchao.prototype.mx_formats.utils import (
6+
to_blocked_per_group_2d,
7+
to_blocked_per_group_3d,
8+
)
9+
10+
logger: logging.Logger = logging.getLogger(__name__)
11+
12+
try:
13+
import fbgemm_gpu.experimental.gen_ai # noqa: F401
14+
except Exception as e:
15+
logging.warning(
16+
f"fbgemm_gpu_genai package is required for this feature but import failed with exception: {e}"
17+
"Please install nightly builds of pytorch and fbgemm_gpu_genai build using this command and try again: "
18+
"pip3 install --force-reinstall --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu129"
19+
"If errors persist, please file a bug report."
20+
)
21+
22+
23+
@torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
24+
def fbgemm_mxfp8_grouped_mm_2d_3d(
25+
A_fp8: torch.Tensor,
26+
A_scales: torch.Tensor,
27+
B_fp8: torch.Tensor,
28+
B_scales: torch.Tensor,
29+
offs: torch.Tensor,
30+
block_size: int = 32,
31+
out_dtype: torch.dtype = torch.bfloat16,
32+
) -> torch.Tensor:
33+
assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D"
34+
assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D"
35+
assert block_size == 32, "Only block_size=32 is supported"
36+
assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
37+
assert A_fp8.shape[-1] == B_fp8.shape[-1], "A_fp8 and B_fp8 must have same last dim"
38+
39+
# Convert scales for each group to blocked format.
40+
Mg, K = A_fp8.shape
41+
A_scales_blocked, starting_row_after_padding = to_blocked_per_group_2d(
42+
A_scales, offs, Mg, K
43+
)
44+
B_scales_blocked = to_blocked_per_group_3d(B_scales)
45+
46+
# From this, we compute `group_sizes` and `starting_row_after_padding`:
47+
# group_sizes = [32, 32, 64]
48+
# starting_row_after_padding = [0, 32, 64, 128]
49+
zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
50+
group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)
51+
52+
# TODO: remove debug logging once prototype is more mature.
53+
_log_inputs(
54+
A_fp8,
55+
B_fp8,
56+
A_scales,
57+
A_scales_blocked,
58+
B_scales,
59+
B_scales_blocked,
60+
offs,
61+
group_sizes,
62+
starting_row_after_padding,
63+
)
64+
65+
out = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
66+
A_fp8,
67+
B_fp8,
68+
A_scales_blocked,
69+
B_scales_blocked,
70+
group_sizes,
71+
starting_row_after_padding=starting_row_after_padding,
72+
)
73+
return out
74+
75+
76+
@fbgemm_mxfp8_grouped_mm_2d_3d.register_fake
77+
def _fbgemm_mxfp8_grouped_mm_2d_3d_fake(
78+
A_fp8: torch.Tensor,
79+
B_fp8: torch.Tensor,
80+
A_scales: torch.Tensor,
81+
B_scales: torch.Tensor,
82+
offs: torch.Tensor,
83+
) -> torch.Tensor:
84+
assert A_fp8.ndim == 2, "A_fp8 tensor must be 2D"
85+
assert B_fp8.ndim == 3, "B_fp8 tensor must be 3D"
86+
mg, k = A_fp8.shape
87+
e, k, n = B_fp8.shape
88+
n_groups = offs.numel()
89+
assert n_groups == e, (
90+
"Size of `offs` (number of groups) must match first dim of `B_fp8`"
91+
)
92+
output = torch.empty((mg, n), dtype=torch.bfloat16, device=A_fp8.device)
93+
return output
94+
95+
96+
def _log_inputs(
97+
A_fp8: torch.Tensor,
98+
B_fp8: torch.Tensor,
99+
A_scales: torch.Tensor,
100+
A_scales_blocked: torch.Tensor,
101+
B_scales: torch.Tensor,
102+
B_scales_blocked: torch.Tensor,
103+
offs: torch.Tensor,
104+
group_sizes: torch.Tensor,
105+
starting_row_after_padding: torch.Tensor,
106+
):
107+
logger.info(f"offs: {offs}, dtype: {offs.dtype}")
108+
logger.info(
109+
f"A_fp8.shape: {A_fp8.shape}, stride: {A_fp8.stride()}, dtype: {A_fp8.dtype}"
110+
)
111+
logger.info(
112+
f"B_fp8.shape: {B_fp8.shape}, stride: {B_fp8.stride()}, dtype: {B_fp8.dtype}"
113+
)
114+
logger.info(
115+
f"A_scales (non-blocked) shape: {A_scales.shape}, stride: {A_scales.stride()}, dtype: {A_scales.dtype}"
116+
)
117+
logger.info(
118+
f"A_scales_blocked.shape: {A_scales_blocked.shape}, stride: {A_scales_blocked.stride()}, dtype: {A_scales_blocked.dtype}"
119+
)
120+
logger.info(
121+
f"B_scales (non-blocked) shape: {B_scales.shape}, stride: {B_scales.stride()}, dtype: {B_scales.dtype}"
122+
)
123+
logger.info(
124+
f"B_scales_blocked.shape: {B_scales_blocked.shape}, stride: {B_scales_blocked.stride()}, dtype: {B_scales_blocked.dtype}"
125+
)
126+
logger.info(
127+
f"group_sizes: {group_sizes}, stride: {group_sizes.stride()}, dtype: {group_sizes.dtype}"
128+
)
129+
logger.info(
130+
f"starting_row_after_padding: {starting_row_after_padding}, stride: {starting_row_after_padding.stride()}, dtype: {starting_row_after_padding.dtype}"
131+
)

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 50 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
1414
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
1515
from torchao.prototype.moe_training.kernels import (
16+
fbgemm_mxfp8_grouped_mm_2d_3d,
1617
triton_fp8_per_group_colwise_scales,
1718
triton_fp8_per_group_rowwise_scales,
1819
triton_fp8_rowwise_3d_transpose_rhs,
@@ -277,52 +278,46 @@ def forward(
277278
offs: Optional[torch.Tensor] = None,
278279
block_size: int = 32,
279280
out_dtype: Optional[torch.dtype] = torch.bfloat16,
280-
emulated: bool = True,
281+
emulated: bool = False,
281282
) -> torch.Tensor:
282283
# torchao _scaled_grouped_mm only supports A=2D and B=3D.
283284
assert A.ndim == 2, "A must be 2D"
284285
assert B_t.ndim == 3, "B must be 3D"
285286
assert block_size == 32, "Only block_size=32 is supported"
286-
assert emulated, "Only emulated mxfp8 grouped gemm is supported"
287+
288+
# Store what we need for backward.
289+
ctx.save_for_backward(A, B_t, offs)
290+
ctx.block_size = block_size
291+
ctx.out_dtype = out_dtype
292+
ctx.emulated = emulated
287293

288294
# Cast to mxpf8 across dim -1.
289295
# A_mx shape: (M, K)
290296
# A_scale shape: (M, K//block_size)
291297
A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
292298

293-
# Cast B_t per-expert to mxfp8 across dim1.
294-
# B_t_mx shape: (E, K, N)
295-
# B_t_scale shape: (E, K//block_size, N)
296-
297-
# To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
299+
# Cast B_t per-expert to mxfp8 across K dim.
298300
# B_mx shape: (E, N, K)
299301
# B_scale shape: (E, N, K//block_size)
300-
B_scales_dim2, B_mx_dim2 = to_mx(
301-
B_t.transpose(-2, -1), # (E,K,N) -> (E,N,K)
302+
B_scales, B_mx = to_mx(
303+
B_t.transpose(-2, -1).contiguous(),
302304
elem_dtype=torch.float8_e4m3fn,
303305
block_size=block_size,
304306
)
305307

306-
# B_t_mx shape: (E, K, N)
307-
# B_t_scale shape: (E, K//block_size, N)
308-
B_t_scales_dim1 = B_scales_dim2.transpose(
309-
-2, -1
310-
) # (E,N,K//block_size) -> (E,K//block_size,N)
311-
B_t_mx_dim1 = B_mx_dim2.transpose(-2, -1) # (E,N,K) -> (E,K,N)
312-
313-
# Store what we need for backward.
314-
ctx.save_for_backward(A, B_t, offs)
315-
ctx.block_size = block_size
316-
ctx.out_dtype = out_dtype
317-
318308
# Perform scaled grouped GEMM and return result.
319309
# output = input @ weight.T
320310
# output shape: (M, N)
321-
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
311+
mxfp8_2d_3d_grouped_mm = (
312+
_emulated_mxfp8_scaled_grouped_mm_2d_3d
313+
if emulated
314+
else fbgemm_mxfp8_grouped_mm_2d_3d
315+
)
316+
out = mxfp8_2d_3d_grouped_mm(
322317
A_mx,
323318
A_scale,
324-
B_t_mx_dim1,
325-
B_t_scales_dim1,
319+
B_mx,
320+
B_scales,
326321
offs=offs,
327322
block_size=block_size,
328323
out_dtype=out_dtype,
@@ -334,6 +329,7 @@ def backward(ctx, grad_out: torch.Tensor):
334329
A, B_t, offs = ctx.saved_tensors
335330
block_size = ctx.block_size
336331
out_dtype = ctx.out_dtype
332+
emulated = ctx.emulated
337333

338334
# grad_out_mx shape: (M, N)
339335
# grad_out_scale shape: (M, N//block_size)
@@ -343,23 +339,24 @@ def backward(ctx, grad_out: torch.Tensor):
343339

344340
# B_mx shape: (E, K, N)
345341
# B_scale shape: (E, K, N//block_size)
346-
B_t_scale_dim2, B_t_mx_dim2 = to_mx(
342+
B_scales, B_mx = to_mx(
347343
B_t.contiguous(),
348344
elem_dtype=torch.float8_e4m3fn,
349345
block_size=block_size,
350346
)
351-
B_scale_dim1 = B_t_scale_dim2.transpose(
352-
-2, -1
353-
) # (E,K,N//block_size) -> (E,N//block_size,K)
354-
B_mx_dim1 = B_t_mx_dim2.transpose(-2, -1) # (E,K,N) -> (E,N,K)
355347

356348
# Compute grad_A.
357349
# grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
358-
grad_A = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
350+
mxfp8_2d_3d_grouped_mm = (
351+
_emulated_mxfp8_scaled_grouped_mm_2d_3d
352+
if emulated
353+
else fbgemm_mxfp8_grouped_mm_2d_3d
354+
)
355+
grad_A = mxfp8_2d_3d_grouped_mm(
359356
grad_out_mx,
360357
grad_out_scale,
361-
B_mx_dim1,
362-
B_scale_dim1,
358+
B_mx,
359+
B_scales,
363360
offs=offs,
364361
out_dtype=out_dtype,
365362
)
@@ -385,7 +382,6 @@ def backward(ctx, grad_out: torch.Tensor):
385382
# Compute grad_B = grad_output_t @ A
386383
# grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
387384
# grad_B = grad_B_t.transpose(-2, -1) = (E,K,N)
388-
389385
grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
390386
grad_out_t_mx,
391387
grad_out_t_scales,
@@ -402,12 +398,30 @@ def backward(ctx, grad_out: torch.Tensor):
402398
def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
403399
A_mx: torch.Tensor,
404400
A_scale: torch.Tensor,
405-
B_t_mx: torch.Tensor,
406-
B_t_scale: torch.Tensor,
401+
B_mx: torch.Tensor,
402+
B_scale: torch.Tensor,
407403
offs: Optional[torch.Tensor] = None,
408404
out_dtype: Optional[torch.dtype] = torch.bfloat16,
409405
block_size: int = 32,
410406
) -> torch.Tensor:
407+
assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
408+
assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
409+
assert A_scale.shape[0] == A_mx.shape[0], (
410+
f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
411+
)
412+
assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
413+
f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
414+
)
415+
assert B_scale.shape[0] == B_mx.shape[0], (
416+
f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
417+
)
418+
assert B_scale.shape[1] == B_mx.shape[1], (
419+
f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
420+
)
421+
assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
422+
f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
423+
)
424+
411425
# Dequantize input
412426
# A_mx shape: (M, K)
413427
# A_scale shape: (M, K//block_size)
@@ -427,14 +441,10 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
427441
A = A.reshape(A_orig_shape)
428442

429443
# Dequantize weights
430-
# B_t_mx shape: (E, K, N)
431-
# B_t_scale shape: (E, K//block_size, N)
432-
E, K, N = B_t_mx.shape
433-
434444
# Tranpose to get block_size on rightmost dim
435445
# B_mx shape: (E, N, K)
436446
# B_scale shape: (E, N, K//block_size)
437-
B_mx, B_scale = B_t_mx.transpose(-2, -1), B_t_scale.transpose(-2, -1)
447+
E, N, K = B_mx.shape
438448

439449
# Reshape to be able to do per-scaling group multiplication
440450
# B_mx shape: (E, N, K//block_size, block_size)

0 commit comments

Comments
 (0)