Commit 24ac553

[mxfp8 moe] add support for fbgemm 2d-3d mx8mx8bf16 grouped gemm
stack-info: PR: #2848, branch: danielvegamyhre/stack/55
1 parent b663faf commit 24ac553

6 files changed: +209 -19 lines changed

test/prototype/moe_training/test_scaled_grouped_mm.py

Lines changed: 22 additions & 13 deletions
@@ -230,25 +230,27 @@ def compute_reference_forward(
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
+    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
     offs = generate_jagged_offs(num_experts, M)
-    x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
+    x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone()

     # Quantize inputs to mxpf8 for emulated mxfp8 scaled grouped mm
     block_size = 32
-    x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+    x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)

     # To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
-    w_scale, w_mx = to_mx(
-        w_t.transpose(-2, -1).contiguous(),
+    w_scale, w_fp8 = to_mx(
+        w,
         elem_dtype=torch.float8_e4m3fn,
         block_size=block_size,
     )
-    w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
+    w_t_scale, w_t_fp8 = w_scale.transpose(-2, -1), w_fp8.transpose(-2, -1)

-    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
+    ref_out = torch._grouped_mm(
+        x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16
+    )
     out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-        x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
+        x_fp8, x_scale, w_t_fp8, w_t_scale, offs=offs, out_dtype=torch.bfloat16
     )

     sqnr = compute_error(ref_out, out)

@@ -305,19 +307,26 @@ def test_emulate_mxfp8_grouped_gemm_2d_2d(M, N, num_experts):


 @skip_if_rocm("ROCm not supported")
-@pytest.mark.parametrize("M,K,N", [(1024, 1024, 1024), (1024, 2048, 4096)])
-@pytest.mark.parametrize("num_experts", (1, 8, 16))
+@pytest.mark.parametrize("M,K,N", [(256, 512, 512)])
+@pytest.mark.parametrize("num_experts", (2,))
 def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
     from torchao.prototype.moe_training.scaled_grouped_mm import (
         _MXFP8GroupedMM,
     )

     block_size = 32
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
-    w_t = torch.randn(
-        num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
+    w = torch.randn(
+        num_experts,
+        N,
+        K,
+        dtype=torch.bfloat16,
+        device="cuda",
     )
-    offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
+    w_t = w.transpose(-2, -1).requires_grad_(True)
+    # offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
+    group_size = M // num_experts
+    offs = torch.arange(group_size, M + 1, group_size, device="cuda", dtype=torch.int32)
     x_ref, w_t_ref, offs_ref = (
         x.clone().detach().requires_grad_(True),
         w_t.clone().detach().requires_grad_(True),

test/prototype/moe_training/test_training.py

Lines changed: 1 addition & 1 deletion
@@ -136,7 +136,7 @@ def moe_module_filter_fn(mod: nn.Module, cur_fqn: str) -> bool:
         ["does.not.exist"],
     ],
 )
-@pytest.mark.parametrize("compile", [False, True])
+@pytest.mark.parametrize("compile", [False])
 def test_moe_mxfp8_training(target_fqns: list[str], compile: bool):
     block_size = 32

torchao/prototype/moe_training/kernels/__init__.py

Lines changed: 3 additions & 0 deletions
@@ -7,3 +7,6 @@
 from torchao.prototype.moe_training.kernels.jagged_float8_scales import (
     triton_fp8_per_group_rowwise_scales as triton_fp8_per_group_rowwise_scales,
 )
+from torchao.prototype.moe_training.kernels.mxfp8 import (
+    fbgemm_mxfp8_grouped_mm_2d_3d as fbgemm_mxfp8_grouped_mm_2d_3d,
+)
torchao/prototype/moe_training/kernels/mxfp8.py

Lines changed: 84 additions & 0 deletions
@@ -0,0 +1,84 @@
+import logging
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+import torch
+
+from torchao.prototype.mx_formats.utils import (
+    to_blocked_per_group_2d,
+    to_blocked_per_group_3d,
+)
+
+try:
+    import fbgemm_gpu.experimental.gen_ai  # noqa: F401
+except Exception as e:
+    logging.warning(
+        f"fbgemm_gpu_genai package is required for this feature but import failed with exception: {e} "
+        "Please install nightly builds of pytorch and fbgemm_gpu_genai using this command and try again: "
+        "pip3 install --force-reinstall --pre torch fbgemm-gpu-genai --index-url https://download.pytorch.org/whl/nightly/cu129 "
+        "If errors persist, please file a bug report."
+    )
+
+
+@torch.library.custom_op("torchao::fbgemm_mxfp8_grouped_mm_2d_3d", mutates_args={})
+def fbgemm_mxfp8_grouped_mm_2d_3d(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx_dim1: torch.Tensor,
+    B_t_scales_dim1: torch.Tensor,
+    offs: torch.Tensor,
+    block_size: int = 32,
+    out_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    assert A_mx.ndim == 2, "A_mx tensor must be 2D"
+    assert B_t_mx_dim1.ndim == 3, "B_t_mx_dim1 tensor must be 3D"
+    assert block_size == 32, "Only block_size=32 is supported"
+    assert out_dtype == torch.bfloat16, "Only out_dtype=bfloat16 is supported"
+
+    # Convert scales for each group to blocked format.
+    Mg, K = A_mx.shape
+    A_scale_blocked, starting_row_after_padding = to_blocked_per_group_2d(A_scale, offs, Mg, K)
+    B_t_scales_dim1_blocked = to_blocked_per_group_3d(B_t_scales_dim1)
+
+    # Compute per-group sizes from the padded start rows, e.g.
+    # starting_row_after_padding = [0, 32, 64, 128] -> group_sizes = [32, 32, 64].
+    group_sizes = torch.diff(starting_row_after_padding).to(torch.int64)
+    print()
+    print("A_mx.shape", A_mx.shape)
+    print("A_scales.shape", A_scale.shape)
+    print("A_scales_blocked.shape", A_scale_blocked.shape)
+    print("B_t_mx_dim1.shape", B_t_mx_dim1.shape)
+    print("B_t_scales_dim1.shape", B_t_scales_dim1.shape)
+    print("B_t_scales_dim1_blocked.shape", B_t_scales_dim1_blocked.shape)
+    print("group_sizes", group_sizes)
+    print("starting_row_after_padding", starting_row_after_padding)
+    out = torch.ops.fbgemm.mx8mx8bf16_grouped_stacked(
+        A_mx,
+        B_t_mx_dim1,
+        A_scale_blocked,
+        B_t_scales_dim1_blocked,
+        group_sizes,
+        starting_row_after_padding=starting_row_after_padding,
+    )
+    return out
+
+
+@fbgemm_mxfp8_grouped_mm_2d_3d.register_fake
+def _fbgemm_mxfp8_grouped_mm_2d_3d_fake(
+    A_mx: torch.Tensor,
+    A_scale: torch.Tensor,
+    B_t_mx_dim1: torch.Tensor,
+    B_t_scales_dim1: torch.Tensor,
+    offs: torch.Tensor,
+    block_size: int = 32,
+    out_dtype: torch.dtype = torch.bfloat16,
+) -> torch.Tensor:
+    assert A_mx.ndim == 2, "A_mx tensor must be 2D"
+    assert B_t_mx_dim1.ndim == 3, "B_t_mx_dim1 tensor must be 3D"
+    mg, k = A_mx.shape
+    e, k, n = B_t_mx_dim1.shape
+    n_groups = offs.numel()
+    assert n_groups == e, (
+        "Size of `offs` (number of groups) must match first dim of `B_t_mx_dim1`"
+    )
+    output = torch.empty((mg, n), dtype=torch.bfloat16, device=A_mx.device)
+    return output
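
For orientation, here is a minimal sketch (not part of the commit) of driving the new op end to end, mirroring the quantization steps in test_emulate_mxfp8_grouped_gemm_2d_3d above. The to_mx import path, the shape conventions, and the availability of the fbgemm mx8mx8bf16_grouped_stacked kernel (nightly fbgemm-gpu-genai build) are assumptions here, not part of this diff.

# Hypothetical usage sketch; assumes a nightly fbgemm-gpu-genai install and that
# to_mx is importable from torchao.prototype.mx_formats.mx_tensor.
import torch
from torchao.prototype.moe_training.kernels import fbgemm_mxfp8_grouped_mm_2d_3d
from torchao.prototype.mx_formats.mx_tensor import to_mx

M, K, N, num_experts, block_size = 256, 512, 512, 2, 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")

# Even token-group boundaries along M, matching the updated fwd/bwd test above.
group_size = M // num_experts
offs = torch.arange(group_size, M + 1, group_size, device="cuda", dtype=torch.int32)

# Cast A along dim -1; cast B_t along dim 1 by casting w along its last dim, then transposing.
x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
w_scale, w_mx = to_mx(w, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)

out = fbgemm_mxfp8_grouped_mm_2d_3d(x_mx, x_scale, w_t_mx, w_t_scale, offs)
print(out.shape)  # expected: (M, N)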

torchao/prototype/moe_training/scaled_grouped_mm.py

Lines changed: 17 additions & 5 deletions
@@ -13,6 +13,7 @@
 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.kernels import (
+    fbgemm_mxfp8_grouped_mm_2d_3d,
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
     triton_fp8_rowwise_3d_transpose_rhs,

@@ -277,13 +278,13 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         block_size: int = 32,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
-        emulated: bool = True,
+        emulated: bool = False,
     ) -> torch.Tensor:
         # torchao _scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"
         assert block_size == 32, "Only block_size=32 is supported"
-        assert emulated, "Only emulated mxfp8 grouped gemm is supported"
+        # assert emulated, "Only emulated mxfp8 grouped gemm is supported"

         # Cast to mxpf8 across dim -1.
         # A_mx shape: (M, K)

@@ -314,11 +315,17 @@
         ctx.save_for_backward(A, B_t, offs)
         ctx.block_size = block_size
         ctx.out_dtype = out_dtype
+        ctx.emulated = emulated

         # Perform scaled grouped GEMM and return result.
         # output = input @ weight.T
         # output shape: (M, N)
-        out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+        mxfp8_2d_3d_grouped_mm = (
+            _emulated_mxfp8_scaled_grouped_mm_2d_3d
+            if emulated
+            else fbgemm_mxfp8_grouped_mm_2d_3d
+        )
+        out = mxfp8_2d_3d_grouped_mm(
             A_mx,
             A_scale,
             B_t_mx_dim1,

@@ -334,6 +341,7 @@ def backward(ctx, grad_out: torch.Tensor):
         A, B_t, offs = ctx.saved_tensors
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
+        emulated = ctx.emulated

         # grad_out_mx shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)

@@ -355,7 +363,12 @@

         # Compute grad_A.
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        grad_A = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+        mxfp8_2d_3d_grouped_mm = (
+            _emulated_mxfp8_scaled_grouped_mm_2d_3d
+            if emulated
+            else fbgemm_mxfp8_grouped_mm_2d_3d
+        )
+        grad_A = mxfp8_2d_3d_grouped_mm(
             grad_out_mx,
             grad_out_scale,
             B_mx_dim1,

@@ -385,7 +398,6 @@
         # Compute grad_B = grad_output_t @ A
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         # grad_B = grad_B_t.transpose(-2, -1) = (E,K,N)
-
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
             grad_out_t_mx,
             grad_out_t_scales,
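
For context, a minimal fwd/bwd sketch (not part of the commit) of exercising _MXFP8GroupedMM with the shapes used in the updated test above. The positional argument order of .apply is inferred from the forward signature in this diff; emulated=True keeps the pure-PyTorch reference path, while the new default (emulated=False) dispatches the output and grad_A grouped mms to the fbgemm kernel.

# Hypothetical sketch; argument order for .apply is inferred from the forward signature above.
import torch
from torchao.prototype.moe_training.scaled_grouped_mm import _MXFP8GroupedMM

M, K, N, num_experts, block_size = 256, 512, 512, 2, 32
x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
w_t = w.transpose(-2, -1).requires_grad_(True)

group_size = M // num_experts
offs = torch.arange(group_size, M + 1, group_size, device="cuda", dtype=torch.int32)

# emulated=True -> _emulated_mxfp8_scaled_grouped_mm_2d_3d reference path;
# emulated=False (new default) -> torchao::fbgemm_mxfp8_grouped_mm_2d_3d.
out = _MXFP8GroupedMM.apply(x, w_t, offs, block_size, torch.bfloat16, True)
out.sum().backward()
print(out.shape, x.grad.shape, w_t.grad.shape)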

torchao/prototype/mx_formats/utils.py

Lines changed: 82 additions & 0 deletions
@@ -99,3 +99,85 @@ def _to_blocked_single(scales: Tensor) -> Tensor:
     assert scales.shape == (128, 4)
     scales_tiled = scales.view(4, 32, 4)  # view as 4 - (32, 4) tiles
     return scales_tiled.transpose(0, 1).reshape(32, 16)  # Interleave tiles
+
+
+def to_blocked_per_group_2d(
+    x_scales: Tensor,
+    group_offs: Tensor,
+    Mg: int,
+    K: int,
+    block_size: int = 32,
+) -> tuple[Tensor, Tensor]:
+    """
+    Convert scales to blocked format per group for a 2D tensor (input activations / token groups).
+
+    Args:
+        x_scales: Tensor with the per-group scales concatenated along dim 0.
+        group_offs: Tensor of shape (num_groups,) containing the end index of each group along the Mg dimension.
+        Mg: total size of all groups summed together
+        K: K dim size
+
+    Returns:
+        blocked_scales: Tensor with the per-group scales in blocked format, concatenated into one tensor.
+        start_row_after_padding: Tensor of shape (num_groups + 1,) containing the start row of each group after padding.
+    """
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
+
+    assert x_scales.ndim == 2, "x_scales must be 2D"
+    assert block_size == 32, "Only block_size=32 is supported for now"
+    blocked_scales_list = []
+    start_row_after_padding_list = [0]
+    group_start_idx = 0
+    for i, group_end_idx in enumerate(group_offs.tolist()):
+        group_size = group_end_idx - group_start_idx
+        prev_start_row_after_padding = start_row_after_padding_list[i]
+        if group_size == 0:
+            start_row_after_padding_list.append(prev_start_row_after_padding)
+            continue
+
+        # Convert group scales to blocked format
+        group_scales = x_scales[group_start_idx:group_end_idx]
+        print("group scales before blocking: ", group_scales.shape)
+        group_scales_blocked = _to_blocked(group_scales)
+        print("group_scales after blocking: ", group_scales_blocked.shape)
+        blocked_scales_list.append(group_scales_blocked)
+
+        # Calculate the start row after padding
+        scaling_groups_per_row = K // block_size
+        rows_for_group = group_scales_blocked.numel() // scaling_groups_per_row
+        new_start_row = prev_start_row_after_padding + rows_for_group
+        start_row_after_padding_list.append(new_start_row)
+
+        # Update next group start index
+        group_start_idx = group_end_idx
+
+    blocked_scales = torch.cat(blocked_scales_list, dim=0).contiguous()
+    print("blocked x_scales before reshape", blocked_scales.shape)
+    blocked_scales = blocked_scales.reshape(-1, K // 32)
+    start_row_after_padding = torch.tensor(
+        start_row_after_padding_list, device=x_scales.device, dtype=torch.int32
+    )
+    return blocked_scales, start_row_after_padding
+
+
+def to_blocked_per_group_3d(weight_scales: Tensor) -> Tensor:
+    """
+    Convert scales to blocked format for each group for a 3D tensor (expert weights).
+
+    Args:
+        weight_scales: Tensor of shape (E, N, K//block_size) containing the per-expert scales.
+    """
+    from fbgemm_gpu.experimental.gemm.triton_gemm.fp4_quantize import _to_blocked
+
+    blocked_scales_list = []
+    num_groups = weight_scales.shape[0]
+    print("weight scales shape original: ", weight_scales.shape)
+    for i in range(num_groups):
+        group_scales = weight_scales[i].t()
+        print("group_scales original shape: ", group_scales.shape)
+        group_scales_blocked = _to_blocked(group_scales)
+        print("group_scales after blocking: ", group_scales_blocked.shape, "i: ", i, "num_groups: ", num_groups)
+        blocked_scales_list.append(group_scales_blocked)
+    print("blocked_scales_list shapes", [s.shape for s in blocked_scales_list])
+    weight_scales_blocked = torch.stack(blocked_scales_list, dim=0)
+    weight_scales_blocked = weight_scales_blocked.reshape(num_groups, -1).contiguous()
+    return weight_scales_blocked
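
To make the blocked-scale layout concrete, here is a small sketch (not part of the commit) of calling these helpers on scales produced by to_mx. The to_mx import path and the exact padding behavior of fbgemm's _to_blocked are assumptions; the shapes mirror the tests in this commit.

# Hypothetical sketch; assumes fbgemm_gpu is installed and that to_mx is importable
# from torchao.prototype.mx_formats.mx_tensor.
import torch
from torchao.prototype.mx_formats.mx_tensor import to_mx
from torchao.prototype.mx_formats.utils import (
    to_blocked_per_group_2d,
    to_blocked_per_group_3d,
)

Mg, K, N, num_experts, block_size = 256, 512, 512, 2, 32
offs = torch.tensor([128, 256], dtype=torch.int32, device="cuda")  # group end indices

# 2D activation scales: one e8m0 scale per 32 contiguous elements along K.
x = torch.randn(Mg, K, dtype=torch.bfloat16, device="cuda")
x_scales, _ = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
x_scales_blocked, start_rows = to_blocked_per_group_2d(x_scales, offs, Mg, K)
# start_rows has num_groups + 1 entries; each entry is a group's first row in the
# padded, blocked layout consumed by the fbgemm grouped kernel.

# 3D expert-weight scales, blocked one expert at a time.
w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
w_scales, _ = to_mx(w, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
w_scales_blocked = to_blocked_per_group_3d(w_scales)
print(x_scales_blocked.shape, start_rows, w_scales_blocked.shape)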
