@@ -297,21 +297,66 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_mm_float8dq(self):
+    @common_utils.parametrize(
+        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
+    )
+    @common_utils.parametrize(
+        "input_shape", [(1, 512), (8, 512), (16, 512), (2, 8, 512), (2, 2, 16, 512)]
+    )
+    def test_mm_float8dq(self, in_features, out_features, input_shape):
         device = "cuda"
         dtype = torch.bfloat16
-        weight = torch.randn(512, 1024).to(device).to(dtype)
+
+        # Adjust input shape to match in_features
+        input_shape = list(input_shape)
+        input_shape[-1] = in_features
+
+        weight = torch.randn(in_features, out_features).to(device).to(dtype)
         weight = weight.t()
 
-        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
-        l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
-        # weight shape: 1024 x 512
-        weight = l.weight
+        ref_linear = (
+            torch.nn.Linear(in_features, out_features, bias=False).to(device).to(dtype)
+        )
+        ref_linear.weight = torch.nn.Parameter(weight.clone())
+
+        test_linear = (
+            torch.nn.Linear(in_features, out_features, bias=False).to(device).to(dtype)
+        )
+        test_linear.weight = torch.nn.Parameter(weight.clone())
+        quantize_(
+            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        quant_weight = test_linear.weight
+
+        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
+        weight_impl = quant_weight.original_weight_tensor.tensor_impl
+
+        self.assertTrue(hasattr(weight_impl, "float8_data"))
+        self.assertTrue(hasattr(weight_impl, "scale"))
+        self.assertFalse(weight_impl.transposed)
+
+        # Verify scale shape for row-wise quantization
+        expected_scale_shape = (out_features, 1)
+        actual_scale_shape = weight_impl.scale.shape
+        self.assertEqual(actual_scale_shape, expected_scale_shape)
+
+        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+
+        input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            ref_output = ref_linear(input_tensor)
+            quant_output = torch.nn.functional.linear(input_tensor, quant_weight)
+
+        expected_output_shape = input_tensor.shape[:-1] + (out_features,)
+        self.assertEqual(quant_output.shape, expected_output_shape)
+
+        max_abs_error = (ref_output - quant_output).abs().max().item()
+        ref_max = ref_output.abs().max().item()
+        relative_error = max_abs_error / ref_max if ref_max > 0 else 0
 
-        input = torch.randn(1, 512, device=device, dtype=dtype)
-        # make sure it runs
-        torch.nn.functional.linear(input, weight)
+        self.assertLess(relative_error, 0.05)
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)