@@ -297,21 +297,53 @@ def test_fp8_weight_dimension_warning(self):
     @unittest.skipIf(
         not is_sm_at_least_89(), "Requires GPU with compute capability >= 8.9"
     )
-    def test_mm_float8dq(self):
+    @common_utils.parametrize(
+        "in_features,out_features", [(512, 1024), (256, 768), (1024, 512)]
+    )
+    @common_utils.parametrize(
+        "leading_shape", [(1,), (8,), (16,), (2, 8,), (2, 2, 16,)]
+    )  # fmt: skip
+    @common_utils.parametrize("bias", [True, False])
+    def test_mm_float8dq(self, in_features, out_features, leading_shape, bias: bool):
         device = "cuda"
         dtype = torch.bfloat16
-        weight = torch.randn(512, 1024).to(device).to(dtype)
-        weight = weight.t()
-
-        l = torch.nn.Linear(512, 1024).to(device).to(dtype)
-        l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()))
-        # weight shape: 1024 x 512
-        weight = l.weight
-
-        input = torch.randn(1, 512, device=device, dtype=dtype)
-        # make sure it runs
-        torch.nn.functional.linear(input, weight)
+        input_shape = leading_shape + (in_features,)
+
+        ref_linear = (
+            torch.nn.Linear(in_features, out_features, bias=bias).to(device).to(dtype)
+        )
+        test_linear = copy.deepcopy(ref_linear)
+        quantize_(
+            test_linear, Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+        )
+
+        quant_weight = test_linear.weight
+
+        self.assertTrue(hasattr(quant_weight, "original_weight_tensor"))
+        weight_impl = quant_weight.original_weight_tensor.tensor_impl
+
+        self.assertTrue(hasattr(weight_impl, "float8_data"))
+        self.assertTrue(hasattr(weight_impl, "scale"))
+        self.assertFalse(weight_impl.transposed)
+
+        # Verify scale shape for row-wise quantization
+        expected_scale_shape = (out_features, 1)
+        actual_scale_shape = weight_impl.scale.shape
+        self.assertEqual(actual_scale_shape, expected_scale_shape)
+
+        self.assertEqual(weight_impl.float8_data.shape, (out_features, in_features))
+
+        input_tensor = torch.randn(*input_shape, device=device, dtype=dtype)
+
+        with torch.no_grad():
+            ref_output = ref_linear(input_tensor)
+            quant_output = torch.nn.functional.linear(input_tensor, quant_weight)
+
+        expected_output_shape = input_tensor.shape[:-1] + (out_features,)
+        self.assertEqual(quant_output.shape, expected_output_shape)
+
+        error = compute_error(ref_output, quant_output)
+        assert error > 20, f"Quantization error is too high, got a SQNR of {error}"
 
 
 common_utils.instantiate_parametrized_tests(TestAffineQuantizedFloat8Compile)