 from torchao.float8.float8_utils import tensor_to_scale, to_fp8_saturated
 from torchao.prototype.moe_training.conversion_utils import MoEScalingType
 from torchao.prototype.moe_training.kernels import (
+    fbgemm_mxfp8_grouped_mm_2d_3d,
     triton_fp8_per_group_colwise_scales,
     triton_fp8_per_group_rowwise_scales,
     triton_fp8_rowwise_3d_transpose_rhs,
@@ -277,52 +278,46 @@ def forward(
         offs: Optional[torch.Tensor] = None,
         block_size: int = 32,
         out_dtype: Optional[torch.dtype] = torch.bfloat16,
-        emulated: bool = True,
+        emulated: bool = False,
     ) -> torch.Tensor:
         # torchao _scaled_grouped_mm only supports A=2D and B=3D.
         assert A.ndim == 2, "A must be 2D"
         assert B_t.ndim == 3, "B must be 3D"
         assert block_size == 32, "Only block_size=32 is supported"
-        assert emulated, "Only emulated mxfp8 grouped gemm is supported"
+
+        # Store what we need for backward.
+        ctx.save_for_backward(A, B_t, offs)
+        ctx.block_size = block_size
+        ctx.out_dtype = out_dtype
+        ctx.emulated = emulated
 
         # Cast to mxfp8 across dim -1.
         # A_mx shape: (M, K)
         # A_scale shape: (M, K//block_size)
         A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
 
-        # Cast B_t per-expert to mxfp8 across dim1.
-        # B_t_mx shape: (E, K, N)
-        # B_t_scale shape: (E, K//block_size, N)
-
-        # To cast B_t per-expert to mxfp8 across dim1, we transpose the experts, cast along dim -1, then untranspose.
+        # Cast B_t per-expert to mxfp8 across K dim.
         # B_mx shape: (E, N, K)
         # B_scale shape: (E, N, K//block_size)
-        B_scales_dim2, B_mx_dim2 = to_mx(
-            B_t.transpose(-2, -1),  # (E,K,N) -> (E,N,K)
+        B_scales, B_mx = to_mx(
+            B_t.transpose(-2, -1).contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
 
-        # B_t_mx shape: (E, K, N)
-        # B_t_scale shape: (E, K//block_size, N)
-        B_t_scales_dim1 = B_scales_dim2.transpose(
-            -2, -1
-        )  # (E,N,K//block_size) -> (E,K//block_size,N)
-        B_t_mx_dim1 = B_mx_dim2.transpose(-2, -1)  # (E,N,K) -> (E,K,N)
-
-        # Store what we need for backward.
-        ctx.save_for_backward(A, B_t, offs)
-        ctx.block_size = block_size
-        ctx.out_dtype = out_dtype
-
         # Perform scaled grouped GEMM and return result.
         # output = input @ weight.T
         # output shape: (M, N)
-        out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+        mxfp8_2d_3d_grouped_mm = (
+            _emulated_mxfp8_scaled_grouped_mm_2d_3d
+            if emulated
+            else fbgemm_mxfp8_grouped_mm_2d_3d
+        )
+        out = mxfp8_2d_3d_grouped_mm(
             A_mx,
             A_scale,
-            B_t_mx_dim1,
-            B_t_scales_dim1,
+            B_mx,
+            B_scales,
             offs=offs,
             block_size=block_size,
             out_dtype=out_dtype,
@@ -334,6 +329,7 @@ def backward(ctx, grad_out: torch.Tensor):
         A, B_t, offs = ctx.saved_tensors
         block_size = ctx.block_size
         out_dtype = ctx.out_dtype
+        emulated = ctx.emulated
 
         # grad_out_mx shape: (M, N)
         # grad_out_scale shape: (M, N//block_size)
@@ -343,23 +339,24 @@ def backward(ctx, grad_out: torch.Tensor):
 
         # B_mx shape: (E, K, N)
         # B_scale shape: (E, K, N//block_size)
-        B_t_scale_dim2, B_t_mx_dim2 = to_mx(
+        B_scales, B_mx = to_mx(
             B_t.contiguous(),
             elem_dtype=torch.float8_e4m3fn,
             block_size=block_size,
         )
-        B_scale_dim1 = B_t_scale_dim2.transpose(
-            -2, -1
-        )  # (E,K,N//block_size) -> (E,N//block_size,K)
-        B_mx_dim1 = B_t_mx_dim2.transpose(-2, -1)  # (E,K,N) -> (E,N,K)
 
         # Compute grad_A.
         # grad_A = scaled grouped mm of (M,N) @ (B,N,K) = (M,K)
-        grad_A = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
+        mxfp8_2d_3d_grouped_mm = (
+            _emulated_mxfp8_scaled_grouped_mm_2d_3d
+            if emulated
+            else fbgemm_mxfp8_grouped_mm_2d_3d
+        )
+        grad_A = mxfp8_2d_3d_grouped_mm(
             grad_out_mx,
             grad_out_scale,
-            B_mx_dim1,
-            B_scale_dim1,
+            B_mx,
+            B_scales,
             offs=offs,
             out_dtype=out_dtype,
         )
@@ -385,7 +382,6 @@ def backward(ctx, grad_out: torch.Tensor):
         # Compute grad_B = grad_output_t @ A
         # grad_B_t = scaled grouped mm of (N,M) @ (M,K) = (E,N,K)
         # grad_B = grad_B_t.transpose(-2, -1) = (E,K,N)
-
         grad_B = _emulated_mxfp8_scaled_grouped_mm_2d_2d(
             grad_out_t_mx,
             grad_out_t_scales,
@@ -402,12 +398,30 @@ def backward(ctx, grad_out: torch.Tensor):
 def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A_mx: torch.Tensor,
     A_scale: torch.Tensor,
-    B_t_mx: torch.Tensor,
-    B_t_scale: torch.Tensor,
+    B_mx: torch.Tensor,
+    B_scale: torch.Tensor,
     offs: Optional[torch.Tensor] = None,
     out_dtype: Optional[torch.dtype] = torch.bfloat16,
     block_size: int = 32,
 ) -> torch.Tensor:
+    assert A_mx.ndim == 2, f"A must be 2D, got {A_mx.ndim}"
+    assert B_mx.ndim == 3, f"B must be 3D, got {B_mx.ndim}"
+    assert A_scale.shape[0] == A_mx.shape[0], (
+        f"A_scale must have same M dim as A_mx, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    )
+    assert A_scale.shape[1] == A_mx.shape[1] // block_size, (
+        f"A_scale dim1 should be size K//block_size, got A={A_mx.shape} and A_scale={A_scale.shape}"
+    )
+    assert B_scale.shape[0] == B_mx.shape[0], (
+        f"B_scale must have same E dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    )
+    assert B_scale.shape[1] == B_mx.shape[1], (
+        f"B_scale must have same N dim as B_mx, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    )
+    assert B_scale.shape[2] == B_mx.shape[2] // block_size, (
+        f"B_scale dim2 should be size K//block_size, got B={B_mx.shape} and B_scale={B_scale.shape}"
+    )
+
     # Dequantize input
     # A_mx shape: (M, K)
     # A_scale shape: (M, K//block_size)
@@ -427,14 +441,10 @@ def _emulated_mxfp8_scaled_grouped_mm_2d_3d(
     A = A.reshape(A_orig_shape)
 
     # Dequantize weights
-    # B_t_mx shape: (E, K, N)
-    # B_t_scale shape: (E, K//block_size, N)
-    E, K, N = B_t_mx.shape
-
     # Transpose to get block_size on rightmost dim
     # B_mx shape: (E, N, K)
     # B_scale shape: (E, N, K//block_size)
-    B_mx, B_scale = B_t_mx.transpose(-2, -1), B_t_scale.transpose(-2, -1)
+    E, N, K = B_mx.shape
 
     # Reshape to be able to do per-scaling group multiplication
     # B_mx shape: (E, N, K//block_size, block_size)
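
For reference, a minimal shape-check sketch of the new casting path introduced above. This is not part of the diff; the to_mx import path and the dummy sizes are assumptions used only for illustration, and the expected shapes come from the comments in the changed code.

import torch
# Assumed import path for to_mx from torchao's mx_formats prototype.
from torchao.prototype.mx_formats.mx_tensor import to_mx

E, M, K, N, block_size = 2, 64, 128, 256, 32
A = torch.randn(M, K, dtype=torch.bfloat16)
B_t = torch.randn(E, K, N, dtype=torch.bfloat16)

# A is cast across dim -1: data (M, K), scales (M, K//block_size).
A_scale, A_mx = to_mx(A, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
assert A_mx.shape == (M, K)
assert A_scale.shape == (M, K // block_size)

# B_t is cast per-expert across the K dim by transposing to (E, N, K) first,
# so the data and scales land directly in the layout the 2d-3d grouped GEMM expects.
B_scales, B_mx = to_mx(
    B_t.transpose(-2, -1).contiguous(),
    elem_dtype=torch.float8_e4m3fn,
    block_size=block_size,
)
assert B_mx.shape == (E, N, K)
assert B_scales.shape == (E, N, K // block_size)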