Skip to content

Commit a5a29db

Browse files
[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm
1 parent b663faf commit a5a29db

File tree

5 files changed

+105
-15
lines changed

5 files changed

+105
-15
lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 17 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -230,25 +230,27 @@ def compute_reference_forward(
230230
@pytest.mark.parametrize("num_experts", (1, 8, 16))
231231
def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
232232
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
233-
w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
233+
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
234234
offs = generate_jagged_offs(num_experts, M)
235-
x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
235+
x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone()
236236

237237
# Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
238238
block_size = 32
239-
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
239+
x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
240240

241241
# To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
242-
w_scale, w_mx = to_mx(
243-
w_t.transpose(-2, -1).contiguous(),
242+
w_scale, w_fp8 = to_mx(
243+
w,
244244
elem_dtype=torch.float8_e4m3fn,
245245
block_size=block_size,
246246
)
247-
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
247+
w_t_scale, w_t_fp8 = w_scale.transpose(-2, -1), w_fp8.transpose(-2, -1)
248248

249-
ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
249+
ref_out = torch._grouped_mm(
250+
x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16
251+
)
250252
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
251-
x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
253+
x_fp8, x_scale, w_t_fp8, w_t_scale, offs=offs, out_dtype=torch.bfloat16
252254
)
253255

254256
sqnr = compute_error(ref_out, out)
@@ -314,9 +316,14 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
314316

315317
block_size = 32
316318
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
317-
w_t = torch.randn(
318-
num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
319+
w = torch.randn(
320+
num_experts,
321+
N,
322+
K,
323+
dtype=torch.bfloat16,
324+
device="cuda",
319325
)
326+
w_t = w.transpose(-2, -1).requires_grad_(True)
320327
offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
321328
x_ref, w_t_ref, offs_ref = (
322329
x.clone().detach().requires_grad_(True),

test/prototype/moe_training/test_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
136136
["does.not.exist"],
137137
],
138138
)
139-
@pytest.mark.parametrize("compile", [False, True])
139+
@pytest.mark.parametrize("compile", [False])
140140
def test_moe_mxfp8_training(target_fqns: list[str], compile: bool):
141141
block_size = 32
142142

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,6 @@
77
from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
88
triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
99
)
10+
from torchao.prototype.moe_training.kernels.mxfp8 import (
11+
fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
12+
)
Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import logging
2+
3+
logger: logging.Logger = logging.getLogger(__name__)
4+
5+
import torch
6+
7+
try:
8+
import fbgemm_gpu.experimental.gen_ai # noqa: F401
9+
except Exception as e:
10+
logging.warning(
11+
f"fbgemm_gpu_genai package is required for this feature but import failed with exception: {e}"
12+
"Please install nightly builds of pytorch and fbgemm_gpu_genai build using this command and try again: "
13+
"pip3 install --force-reinstall --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu129"
14+
"If errors persist, please file a bug report."
15+
)
16+
17+
18+
@torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
19+
def fbgemm_mxfp8_grouped_mm_2d_3d(
20+
A_mx: torch.Tensor,
21+
A_scale: torch.Tensor,
22+
B_t_mx_dim1: torch.Tensor,
23+
B_t_scales_dim1: torch.Tensor,
24+
offs: torch.Tensor,
25+
block_size: int = 32,
26+
out_dtype: torch.dtype = torch.bfloat16,
27+
) -> torch.Tensor:
28+
assert A_mx.ndim == 2, "A_mx tensor must be 2D"
29+
assert B_t_mx_dim1.ndim == 3, "B_t_mx_dim1 tensor must be 3D"
30+
assert block_size == 32, "Only block_size=32 is supported"
31+
assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
32+
33+
# "A" and "offs" already have been padded so token group sizes along Mg are multiples of scaling block size (32).
34+
# e.g. offs = [32, 64, 128]
35+
# From this, we compute `group_sizes` and `starting_row_after_padding`:
36+
# group_sizes = [32, 32, 64]
37+
# starting_row_after_padding = [0, 32, 64, 128]
38+
zero = torch.tensor([0], dtype=offs.dtype, device=offs.device)
39+
group_sizes = torch.diff(offs, prepend=zero).to(torch.int64)
40+
starting_row_after_paddding = torch.cat((zero, offs))
41+
out = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
42+
A_mx, # (Mg, K)
43+
B_t_mx_dim1, # (E, K, N)
44+
A_scale, # (Mg, K//block_size)
45+
B_t_scales_dim1, # (E, K//block_size, N)
46+
group_sizes, # size of each token group, computed from end idx of each group (`offs`)
47+
starting_row_after_padding=starting_row_after_paddding,
48+
)
49+
return out
50+
51+
52+
@fbgemm_mxfp8_grouped_mm_2d_3d.register_fake
53+
def _fbgemm_mxfp8_grouped_mm_2d_3d_fake(
54+
A_mx: torch.Tensor,
55+
B_t_mx_dim1: torch.Tensor,
56+
A_scale: torch.Tensor,
57+
B_t_scales_dim1: torch.Tensor,
58+
offs: torch.Tensor,
59+
) -> torch.Tensor:
60+
assert A_mx.ndim == 2, "A_mx tensor must be 2D"
61+
assert B_t_mx_dim1.ndim == 3, "B_t_mx_dim1 tensor must be 3D"
62+
mg, k = A_mx.shape
63+
e, k, n = B_t_mx_dim1.shape
64+
n_groups = offs.numel()
65+
assert n_groups == e, (
66+
"Size of `offs` (number of groups) must match first dim of `B_t_mx_dim1`"
67+
)
68+
output = torch.empty((mg, n), dtype=torch.bfloat16, device=A_mx.device)
69+
return output

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
1414
from torchao.prototype.moe_training.conversion_utils import MoEScalingType
1515
from torchao.prototype.moe_training.kernels import (
16+
fbgemm_mxfp8_grouped_mm_2d_3d,
1617
triton_fp8_per_group_colwise_scales,
1718
triton_fp8_per_group_rowwise_scales,
1819
triton_fp8_rowwise_3d_transpose_rhs,
@@ -283,7 +284,6 @@ def forward(
283284
assert A.ndim == 2, "A must be 2D"
284285
assert B_t.ndim == 3, "B must be 3D"
285286
assert block_size == 32, "Only block_size=32 is supported"
286-
assert emulated, "Only emulated mxfp8 grouped gemm is supported"
287287

288288
# Cast to mxpf8 across dim -1.
289289
# A_mx shape: (M, K)
@@ -314,11 +314,17 @@ def forward(
314314
ctx.save_for_backward(A, B_t, offs)
315315
ctx.block_size = block_size
316316
ctx.out_dtype = out_dtype
317+
ctx.emulated = emulated
317318

318319
# Perform scaled grouped GEMM and return result.
319320
# output = input @ weight.T
320321
# output shape: (M, N)
321-
out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
322+
mxfp8_2d_3d_grouped_mm = (
323+
_emulated_mxfp8_scaled_grouped_mm_2d_3d
324+
if emulated
325+
else fbgemm_mxfp8_grouped_mm_2d_3d
326+
)
327+
out = mxfp8_2d_3d_grouped_mm(
322328
A_mx,
323329
A_scale,
324330
B_t_mx_dim1,
@@ -334,6 +340,7 @@ def backward(ctx, grad_out: torch.Tensor):
334340
A, B_t, offs = ctx.saved_tensors
335341
block_size = ctx.block_size
336342
out_dtype = ctx.out_dtype
343+
emulated = ctx.emulated
337344

338345
# grad_out_mx shape: (M, N)
339346
# grad_out_scale shape: (M, N//block_size)
@@ -355,7 +362,12 @@ def backward(ctx, grad_out: torch.Tensor):
355362

356363
# Compute grad_A.
357364
# grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
358-
grad_A = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
365+
mxfp8_2d_3d_grouped_mm = (
366+
_emulated_mxfp8_scaled_grouped_mm_2d_3d
367+
if emulated
368+
else fbgemm_mxfp8_grouped_mm_2d_3d
369+
)
370+
grad_A = mxfp8_2d_3d_grouped_mm(
359371
grad_out_mx,
360372
grad_out_scale,
361373
B_mx_dim1,
@@ -385,7 +397,6 @@ def backward(ctx, grad_out: torch.Tensor):
385397
# Compute grad_B = grad_output_t @ A
386398
# grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
387399
# grad_B = grad_B_t.transpose(-2, -1) = (E,K,N)
388-
389400
grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
390401
grad_out_t_mx,
391402
grad_out_t_scales,

0 commit comments

Comments
 (0)