@@ -230,25 +230,27 @@ def compute_reference_forward(
 @pytest.mark.parametrize("num_experts", (1, 8, 16))
 def test_emulate_mxfp8_grouped_gemm_2d_3d(M, K, N, num_experts):
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda")
-    w_t = torch.randn(num_experts, K, N, dtype=torch.bfloat16, device="cuda")
+    w = torch.randn(num_experts, N, K, dtype=torch.bfloat16, device="cuda")
     offs = generate_jagged_offs(num_experts, M)
-    x_ref, w_t_ref, offs_ref = x.clone(), w_t.clone(), offs.clone()
+    x_ref, w_ref, offs_ref = x.clone(), w.clone(), offs.clone()
 
     # Quantize inputs to mxfp8 for emulated mxfp8 scaled grouped mm
     block_size = 32
-    x_scale, x_mx = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
+    x_scale, x_fp8 = to_mx(x, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
 
     # To cast w_t per-expert to mxfp8 across dim1, we cast w along dim -1, then transpose the quantized data and scales.
-    w_scale, w_mx = to_mx(
-        w_t.transpose(-2, -1).contiguous(),
+    w_scale, w_fp8 = to_mx(
+        w,
         elem_dtype=torch.float8_e4m3fn,
         block_size=block_size,
     )
-    w_t_scale, w_t_mx = w_scale.transpose(-2, -1), w_mx.transpose(-2, -1)
+    w_t_scale, w_t_fp8 = w_scale.transpose(-2, -1), w_fp8.transpose(-2, -1)
 
-    ref_out = torch._grouped_mm(x_ref, w_t_ref, offs=offs_ref, out_dtype=torch.bfloat16)
+    ref_out = torch._grouped_mm(
+        x_ref, w_ref.transpose(-2, -1), offs=offs_ref, out_dtype=torch.bfloat16
+    )
     out = _emulated_mxfp8_scaled_grouped_mm_2d_3d(
-        x_mx, x_scale, w_t_mx, w_t_scale, offs=offs, out_dtype=torch.bfloat16
+        x_fp8, x_scale, w_t_fp8, w_t_scale, offs=offs, out_dtype=torch.bfloat16
     )
 
     sqnr = compute_error(ref_out, out)
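Note: the dim1 cast above works because to_mx computes one shared scale per block_size elements along the last dim only. To get w_t (logically (E, K, N)) with scale blocks running along K, the test casts w of shape (E, N, K) along dim -1 and transposes both the fp8 data and the scales afterwards. A minimal sketch of the shape bookkeeping, assuming torchao's to_mx import path and illustrative sizes:

    import torch
    # Assumed import path for the MX quantizer used by these tests.
    from torchao.prototype.mx_formats.mx_tensor import to_mx

    E, N, K, block_size = 2, 16, 64, 32
    w = torch.randn(E, N, K, dtype=torch.bfloat16, device="cuda")

    # to_mx scales along the LAST dim: one scale per block_size K-elements.
    w_scale, w_fp8 = to_mx(w, elem_dtype=torch.float8_e4m3fn, block_size=block_size)
    # w_fp8:   (E, N, K)
    # w_scale: (E, N, K // block_size)

    # Transposing data and scales together yields w_t "cast across dim1":
    # each scale block still covers 32 consecutive K values, now along dim 1.
    w_t_fp8 = w_fp8.transpose(-2, -1)      # (E, K, N)
    w_t_scale = w_scale.transpose(-2, -1)  # (E, K // block_size, N)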
@@ -314,9 +316,14 @@ def test_mxfp8_grouped_gemm_with_dq_fwd_bwd(M, K, N, num_experts):
 
     block_size = 32
     x = torch.randn(M, K, dtype=torch.bfloat16, device="cuda", requires_grad=True)
-    w_t = torch.randn(
-        num_experts, K, N, dtype=torch.bfloat16, device="cuda", requires_grad=True
+    w = torch.randn(
+        num_experts,
+        N,
+        K,
+        dtype=torch.bfloat16,
+        device="cuda",
     )
+    w_t = w.transpose(-2, -1).requires_grad_(True)
     offs = generate_jagged_offs(num_experts, M, multiple_of=block_size)
     x_ref, w_t_ref, offs_ref = (
         x.clone().detach().requires_grad_(True),
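Note: in the fwd/bwd test, w_t is now built as a zero-copy transposed view of w rather than allocated directly in (E, K, N) layout, so its storage stays K-contiguous while requires_grad_(True) makes the view itself the autograd leaf. A small sketch of what that implies, with illustrative shapes:

    import torch

    E, N, K = 4, 64, 32
    w = torch.randn(E, N, K, dtype=torch.bfloat16)

    # transpose swaps strides only; no data is copied.
    w_t = w.transpose(-2, -1).requires_grad_(True)
    print(w_t.shape)            # torch.Size([4, 32, 64]) == (E, K, N)
    print(w_t.is_contiguous())  # False: memory is still laid out as (E, N, K)
    print(w.stride())           # (2048, 32, 1)
    print(w_t.stride())         # (2048, 1, 32)

    # w_t is the leaf, so gradients accumulate on w_t rather than w.
    w_t.float().sum().backward()
    print(w_t.grad.shape)       # torch.Size([4, 32, 64])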