Commit 9a770a5

Fix parametrized tests
Differential Revision: D82487961 Pull Request resolved: #3007
1 parent 4dffb40 commit 9a770a5
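
The change moves this test file off the third-party parameterized package and onto PyTorch's internal test utilities: the test class now subclasses torch.testing._internal.common_utils.TestCase, each parameter sweep is expressed with @parametrize, and instantiate_parametrized_tests generates one test per parameter tuple at module level. The following is a rough, self-contained sketch of that pattern only; the toy class, test method, and parameter values are illustrative and not part of this commit.

import unittest

import torch
from torch.testing._internal.common_utils import (
    TestCase,
    instantiate_parametrized_tests,
    parametrize,
)


class ToyParametrizedTest(TestCase):
    # Each (weight_dtype, group_size) tuple becomes its own generated test.
    @parametrize(
        "weight_dtype, group_size",
        [
            (weight_dtype, group_size)
            for weight_dtype in [torch.int8, torch.int32]
            for group_size in [32, 64]
        ],
    )
    def test_dummy(self, weight_dtype, group_size):
        # Trivial stand-in for the real accuracy/serialization checks.
        self.assertEqual(group_size % 32, 0)
        self.assertTrue(str(weight_dtype).startswith("torch.int"))


# Attaches the generated per-parameter tests to the class so plain
# unittest discovers them.
instantiate_parametrized_tests(ToyParametrizedTest)

if __name__ == "__main__":
    unittest.main()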

1 file changed: 55 additions, 67 deletions


test/quantization/test_int8_dynamic_activation_intx_weight_config_v1.py

Lines changed: 55 additions & 67 deletions
@@ -10,8 +10,12 @@
 import unittest

 import torch
-from parameterized import param, parameterized
 from torch.testing import FileCheck
+from torch.testing._internal.common_utils import (
+    TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
+)

 from torchao.dtypes import PackedLinearInt8DynamicActivationIntxWeightLayout, QDQLayout
 from torchao.quantization.granularity import PerAxis, PerGroup
@@ -34,42 +38,35 @@


 @unittest.skipIf(not _is_kernel_library_loaded(), "Kernel library not loaded")
-class TestInt8DynamicActivationIntxWeight(unittest.TestCase):
-    TEST_ACCURACY_CASES = [
-        param(
-            layout=layout,
-            weight_dtype=weight_dtype,
-            weight_mapping_type=weight_mapping_type,
-            weight_granularity=weight_granularity,
-        )
-        for layout in [
-            PackedLinearInt8DynamicActivationIntxWeightLayout(),
-            PackedLinearInt8DynamicActivationIntxWeightLayout(target="universal"),
-        ]
-        for weight_dtype in [
-            torch.int1,
-            torch.int2,
-            torch.int3,
-            torch.int4,
-            torch.int5,
-            torch.int6,
-            torch.int7,
-            torch.int8,
-        ]
-        for weight_mapping_type in [
-            MappingType.SYMMETRIC,
-            MappingType.ASYMMETRIC,
-            MappingType.SYMMETRIC_NO_CLIPPING_ERR,
-        ]
-        for weight_granularity in [
-            PerGroup(128),
-            PerAxis(0),
-        ]
-    ]
-
-    @parameterized.expand(
-        TEST_ACCURACY_CASES,
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
+class TestInt8DynamicActivationIntxWeight(TestCase):
+    @parametrize(
+        "layout, weight_dtype, weight_mapping_type, weight_granularity",
+        [
+            (layout, weight_dtype, weight_mapping_type, weight_granularity)
+            for layout in [
+                PackedLinearInt8DynamicActivationIntxWeightLayout(),
+                PackedLinearInt8DynamicActivationIntxWeightLayout(target="universal"),
+            ]
+            for weight_dtype in [
+                torch.int1,
+                torch.int2,
+                torch.int3,
+                torch.int4,
+                torch.int5,
+                torch.int6,
+                torch.int7,
+                torch.int8,
+            ]
+            for weight_mapping_type in [
+                MappingType.SYMMETRIC,
+                MappingType.ASYMMETRIC,
+                MappingType.SYMMETRIC_NO_CLIPPING_ERR,
+            ]
+            for weight_granularity in [
+                PerGroup(128),
+                PerAxis(0),
+            ]
+        ],
     )
     def test_accuracy(
         self, layout, weight_dtype, weight_mapping_type, weight_granularity
@@ -396,15 +393,12 @@ def test_export_QDQLayout(self):
             exported.graph_module.code
         )

-    @parameterized.expand(
+    @parametrize(
+        "layout",
         [
-            param(layout=layout)
-            for layout in [
-                PackedLinearInt8DynamicActivationIntxWeightLayout(),
-                QDQLayout(),
-            ]
+            PackedLinearInt8DynamicActivationIntxWeightLayout(),
+            QDQLayout(),
         ],
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
     def test_serialization(self, layout):
         layers = [
@@ -436,20 +430,16 @@ def test_serialization(self, layout):
         actual = model2(activations)
         self.assertTrue(torch.allclose(expected, actual))

-    @parameterized.expand(
+    @parametrize(
+        "group_size, mapping_type, act_mapping_type",
         [
-            param(
-                group_size=group_size,
-                mapping_type=mapping_type,
-                act_mapping_type=act_mapping_type,
-            )
+            (group_size, mapping_type, act_mapping_type)
             for group_size, mapping_type, act_mapping_type in zip(
                 [32, 64],
                 [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
                 [MappingType.ASYMMETRIC, MappingType.SYMMETRIC],
             )
         ],
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
     def test_identical_to_Int8DynamicActivationInt4WeightConfig(
         self, group_size, mapping_type, act_mapping_type
@@ -490,15 +480,16 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
         sqnr = compute_error(model(activations), model_copy(activations)).item()
         self.assertTrue(sqnr == float("inf"))

-    @parameterized.expand(
+    @parametrize(
+        "weight_dtype, group_size, mapping_type, act_mapping_type, scale_dtype, model_dtype",
         [
-            param(
-                weight_dtype=weight_dtype,
-                group_size=group_size,
-                mapping_type=mapping_type,
-                act_mapping_type=act_mapping_type,
-                scale_dtype=scale_dtype,
-                model_dtype=model_dtype,
+            (
+                weight_dtype,
+                group_size,
+                mapping_type,
+                act_mapping_type,
+                scale_dtype,
+                model_dtype,
             )
             for weight_dtype in list(getattr(torch, f"int{x}") for x in range(1, 9))
             for group_size in [32, 64, 128]
@@ -507,7 +498,6 @@ def test_identical_to_Int8DynamicActivationInt4WeightConfig(
             for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
     def test_identical_to_IntXQuantizationAwareTrainingConfig(
         self,
@@ -582,18 +572,14 @@ def test_identical_to_IntXQuantizationAwareTrainingConfig(
         sqnr = compute_error(prepared_out, converted_out).item()
         self.assertTrue(sqnr == float("inf"))

-    @parameterized.expand(
+    @parametrize(
+        "group_size, scale_dtype, model_dtype",
         [
-            param(
-                group_size=group_size,
-                scale_dtype=scale_dtype,
-                model_dtype=model_dtype,
-            )
+            (group_size, scale_dtype, model_dtype)
             for group_size in [32, 64, 128]
             for scale_dtype in [torch.float32, torch.bfloat16, torch.float16]
             for model_dtype in [torch.float32, torch.bfloat16, torch.float16]
         ],
-        name_func=lambda f, _, params: f.__name__ + f"_{params.kwargs}",
     )
     def test_identical_to_Int8DynActInt4WeightQATQuantizer(
         self, group_size, scale_dtype, model_dtype
@@ -690,5 +676,7 @@ def test_moe_quant_intx(self):
         self.assertGreater(compute_error(out_qc, out), 30)


+instantiate_parametrized_tests(TestInt8DynamicActivationIntxWeight)
+
 if __name__ == "__main__":
     unittest.main()
