1010import unittest
1111
1212import torch
13- from parameterized import param , parameterized
1413from torch .testing import FileCheck
14+ from torch .testing ._internal .common_utils import (
15+ TestCase ,
16+ instantiate_parametrized_tests ,
17+ parametrize ,
18+ )
1519
1620from torchao .dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout , QDQLayout
1721from torchao .quantization .granularity import PerAxis , PerGroup
3438
3539
3640@unittest .skipIf (not _is_kernel_library_loaded (), "Kernel library not loaded" )
37- class TestInt8DynamicActivationIntxWeight (unittest .TestCase ):
38- TEST_ACCURACY_CASES = [
39- param (
40- layout = layout ,
41- weight_dtype = weight_dtype ,
42- weight_mapping_type = weight_mapping_type ,
43- weight_granularity = weight_granularity ,
44- )
45- for layout in [
46- PackedLinearInt8DynamicActivationIntxWeightLayout (),
47- PackedLinearInt8DynamicActivationIntxWeightLayout (target = "universal" ),
48- ]
49- for weight_dtype in [
50- torch .int1 ,
51- torch .int2 ,
52- torch .int3 ,
53- torch .int4 ,
54- torch .int5 ,
55- torch .int6 ,
56- torch .int7 ,
57- torch .int8 ,
58- ]
59- for weight_mapping_type in [
60- MappingType .SYMMETRIC ,
61- MappingType .ASYMMETRIC ,
62- MappingType .SYMMETRIC_NO_CLIPPING_ERR ,
63- ]
64- for weight_granularity in [
65- PerGroup (128 ),
66- PerAxis (0 ),
67- ]
68- ]
69-
70- @parameterized .expand (
71- TEST_ACCURACY_CASES ,
72- name_func = lambda f , _ , params : f .__name__ + f"_{ params .kwargs } " ,
41+ class TestInt8DynamicActivationIntxWeight (TestCase ):
42+ @parametrize (
43+ "layout, weight_dtype, weight_mapping_type, weight_granularity" ,
44+ [
45+ (layout , weight_dtype , weight_mapping_type , weight_granularity )
46+ for layout in [
47+ PackedLinearInt8DynamicActivationIntxWeightLayout (),
48+ PackedLinearInt8DynamicActivationIntxWeightLayout (target = "universal" ),
49+ ]
50+ for weight_dtype in [
51+ torch .int1 ,
52+ torch .int2 ,
53+ torch .int3 ,
54+ torch .int4 ,
55+ torch .int5 ,
56+ torch .int6 ,
57+ torch .int7 ,
58+ torch .int8 ,
59+ ]
60+ for weight_mapping_type in [
61+ MappingType .SYMMETRIC ,
62+ MappingType .ASYMMETRIC ,
63+ MappingType .SYMMETRIC_NO_CLIPPING_ERR ,
64+ ]
65+ for weight_granularity in [
66+ PerGroup (128 ),
67+ PerAxis (0 ),
68+ ]
69+ ],
7370 )
7471 def test_accuracy (
7572 self , layout , weight_dtype , weight_mapping_type , weight_granularity
@@ -396,15 +393,12 @@ def test_export_QDQLayout(self):
396393 exported .graph_module .code
397394 )
398395
399- @parameterized .expand (
396+ @parametrize (
397+ "layout" ,
400398 [
401- param (layout = layout )
402- for layout in [
403- PackedLinearInt8DynamicActivationIntxWeightLayout (),
404- QDQLayout (),
405- ]
399+ PackedLinearInt8DynamicActivationIntxWeightLayout (),
400+ QDQLayout (),
406401 ],
407- name_func = lambda f , _ , params : f .__name__ + f"_{ params .kwargs } " ,
408402 )
409403 def test_serialization (self , layout ):
410404 layers = [
@@ -436,20 +430,16 @@ def test_serialization(self, layout):
436430 actual = model2 (activations )
437431 self .assertTrue (torch .allclose (expected , actual ))
438432
439- @parameterized .expand (
433+ @parametrize (
434+ "group_size, mapping_type, act_mapping_type" ,
440435 [
441- param (
442- group_size = group_size ,
443- mapping_type = mapping_type ,
444- act_mapping_type = act_mapping_type ,
445- )
436+ (group_size , mapping_type , act_mapping_type )
446437 for group_size , mapping_type , act_mapping_type in zip (
447438 [32 , 64 ],
448439 [MappingType .ASYMMETRIC , MappingType .SYMMETRIC ],
449440 [MappingType .ASYMMETRIC , MappingType .SYMMETRIC ],
450441 )
451442 ],
452- name_func = lambda f , _ , params : f .__name__ + f"_{ params .kwargs } " ,
453443 )
454444 def test_identical_to_Int8DynamicActivationInt4WeightConfig (
455445 self , group_size , mapping_type , act_mapping_type
@@ -490,15 +480,16 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
490480 sqnr = compute_error (model (activations ), model_copy (activations )).item ()
491481 self .assertTrue (sqnr == float ("inf" ))
492482
493- @parameterized .expand (
483+ @parametrize (
484+ "weight_dtype, group_size, mapping_type, act_mapping_type, scale_dtype, model_dtype" ,
494485 [
495- param (
496- weight_dtype = weight_dtype ,
497- group_size = group_size ,
498- mapping_type = mapping_type ,
499- act_mapping_type = act_mapping_type ,
500- scale_dtype = scale_dtype ,
501- model_dtype = model_dtype ,
486+ (
487+ weight_dtype ,
488+ group_size ,
489+ mapping_type ,
490+ act_mapping_type ,
491+ scale_dtype ,
492+ model_dtype ,
502493 )
503494 for weight_dtype in list (getattr (torch , f"int{ x } " ) for x in range (1 , 9 ))
504495 for group_size in [32 , 64 , 128 ]
@@ -507,7 +498,6 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
507498 for scale_dtype in [torch .float32 , torch .bfloat16 , torch .float16 ]
508499 for model_dtype in [torch .float32 , torch .bfloat16 , torch .float16 ]
509500 ],
510- name_func = lambda f , _ , params : f .__name__ + f"_{ params .kwargs } " ,
511501 )
512502 def test_identical_to_IntXQuantizationAwareTrainingConfig (
513503 self ,
@@ -582,18 +572,14 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
582572 sqnr = compute_error (prepared_out , converted_out ).item ()
583573 self .assertTrue (sqnr == float ("inf" ))
584574
585- @parameterized .expand (
575+ @parametrize (
576+ "group_size, scale_dtype, model_dtype" ,
586577 [
587- param (
588- group_size = group_size ,
589- scale_dtype = scale_dtype ,
590- model_dtype = model_dtype ,
591- )
578+ (group_size , scale_dtype , model_dtype )
592579 for group_size in [32 , 64 , 128 ]
593580 for scale_dtype in [torch .float32 , torch .bfloat16 , torch .float16 ]
594581 for model_dtype in [torch .float32 , torch .bfloat16 , torch .float16 ]
595582 ],
596- name_func = lambda f , _ , params : f .__name__ + f"_{ params .kwargs } " ,
597583 )
598584 def test_identical_to_Int8DynActInt4WeightQATQuantizer (
599585 self , group_size , scale_dtype , model_dtype
@@ -690,5 +676,7 @@ def test_moe_quant_intx(self):
690676 self .assertGreater (compute_error (out_qc , out ), 30 )
691677
692678
679+ instantiate_parametrized_tests (TestInt8DynamicActivationIntxWeight )
680+
693681if __name__ == "__main__" :
694682 unittest .main ()
0 commit comments