Commit 4872c4f
Move packing format used by int4 to int4_packing_format.py (#2946)
Summary: We found that there is not much reuse of packing formats across dtypes, so we now plan to define a packing format per dtype (int4, float8, intx) instead of having a single global packing_format shared by all tensors; this reduces interference between the configs for different dtypes. The tensor subclasses themselves are unchanged, so there are no BC changes for them. For v2 of Int4WeightOnlyConfig this does break BC, but no official models have been saved with that config yet, so it's fine. We also didn't add BC testing for this since the API is not finalized yet; we'll add that later.

Test Plan: Regression tests:

python test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
python test/core/test_config.py
python test/integration/test_load_and_run_checkpoint.py
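As a sketch of the user-facing change (not part of the commit itself): with v2 configs the argument is now `int4_packing_format`, mirroring the updated tests below. The tiny model here is hypothetical, and per the test skips the "plain" format additionally needs CUDA sm90+ and fbgemm-gpu-genai.

import torch
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int4WeightOnlyConfig

# v2 config: `packing_format` is renamed to `int4_packing_format`, whose
# values ("plain", "preshuffled", "marlin_sparse", ...) now come from the
# int4-specific Int4PackingFormat enum instead of the shared PackingFormat.
config = Int4WeightOnlyConfig(
    group_size=128,
    int4_packing_format="plain",
    version=2,
)

# hypothetical module, just to show the call shape
model = torch.nn.Sequential(
    torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda")
)
quantize_(model, config)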
1 parent 8901ff2 commit 4872c4f

File tree

11 files changed: +97 −49 lines

test/core/test_config.py
Lines changed: 3 additions & 1 deletion

@@ -26,6 +26,7 @@
 from torchao.quantization.quant_api import (
     FbgemmConfig,
     Float8DynamicActivationFloat8WeightConfig,
+    Float8DynamicActivationInt4WeightConfig,
     Float8WeightOnlyConfig,
     FPXWeightOnlyConfig,
     GemliteUIntXWeightOnlyConfig,
@@ -49,13 +50,14 @@
         weight_dtype=torch.float8_e4m3fn,
     ),
     UIntXWeightOnlyConfig(dtype=torch.uint1),
+    Float8DynamicActivationInt4WeightConfig(),
     Int4DynamicActivationInt4WeightConfig(),
     Int4WeightOnlyConfig(
         group_size=32,
     ),
     Int4WeightOnlyConfig(
         group_size=128,
-        packing_format="tile_packed_to_4d",
+        int4_packing_format="tile_packed_to_4d",
         int4_choose_qparams_algorithm="hqq",
         version=2,
     ),

test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
Lines changed: 1 addition & 1 deletion

@@ -26,7 +26,7 @@

 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
-    packing_format="marlin_sparse",
+    int4_packing_format="marlin_sparse",
     version=2,
 )

test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@
 def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
-        packing_format="opaque",
+        int4_packing_format="opaque",
         version=2,
     )

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
Lines changed: 1 addition & 1 deletion

@@ -28,7 +28,7 @@
 def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
-        packing_format="plain_int32",
+        int4_packing_format="plain_int32",
         version=2,
     )

test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
Lines changed: 2 additions & 2 deletions

@@ -29,13 +29,13 @@

 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
-    packing_format="preshuffled",
+    int4_packing_format="preshuffled",
     version=2,
 )

 # only 128 group_size is supported
 FP8_ACT_CONFIG = Float8DynamicActivationInt4WeightConfig(
-    packing_format="preshuffled",
+    int4_packing_format="preshuffled",
 )

test/quantization/quantize_/workflows/int4/test_int4_tensor.py
Lines changed: 9 additions & 2 deletions

@@ -17,17 +17,24 @@
 from torchao.quantization.quantize_.common import SupportsActivationPreScaling
 from torchao.quantization.utils import compute_error
 from torchao.testing.utils import TorchAOIntegrationTestCase
-from torchao.utils import is_sm_at_least_90, torch_version_at_least
+from torchao.utils import (
+    _is_fbgemm_genai_gpu_available,
+    is_sm_at_least_90,
+    torch_version_at_least,
+)


 @unittest.skipIf(not torch_version_at_least("2.8.0"), "Need pytorch 2.8+")
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_90(), "Nedd sm90+")
+@unittest.skipIf(
+    not _is_fbgemm_genai_gpu_available(), "Requires fbgemm-gpu-genai >= 1.2.0"
+)
 class TestInt4Tensor(TorchAOIntegrationTestCase):
     def setUp(self):
         self.config = Int4WeightOnlyConfig(
             group_size=128,
-            packing_format="plain",
+            int4_packing_format="plain",
             version=2,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []

test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py
Lines changed: 2 additions & 2 deletions

@@ -24,13 +24,13 @@

 INT4_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
-    packing_format="tile_packed_to_4d",
+    int4_packing_format="tile_packed_to_4d",
     version=2,
 )

 INT4_HQQ_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
-    packing_format="tile_packed_to_4d",
+    int4_packing_format="tile_packed_to_4d",
     int4_choose_qparams_algorithm="hqq",
     version=2,
 )

torchao/quantization/quant_api.py
Lines changed: 18 additions & 17 deletions

@@ -75,6 +75,7 @@
     Int4ChooseQParamsAlgorithm,
     Int4MarlinSparseTensor,
     Int4OpaqueTensor,
+    Int4PackingFormat,
     Int4PlainInt32Tensor,
     Int4PreshuffledTensor,
     Int4Tensor,
@@ -1075,7 +1076,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     Note:
         Current state for Int4WeightOnlyConfig is that it supports both v1 (legacy) and v2

-        For v2 (version = 2), only `group_size`, `packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored
+        For v2 (version = 2), only `group_size`, `int4_packing_format`, `int4_choose_qparams_algorithm` and `set_inductor_config` are valid, all other args will be ignored
         For v1 (version = 1), only `group_size`, `layout`, `use_hqq`, `zero_point_domain`, `preserve_zero` and `set_inductor_config` are valid, we plan to deprecate v1 in torchao 0.15 to make this config
         less confusing
     """
@@ -1087,7 +1088,7 @@ class Int4WeightOnlyConfig(AOBaseConfig):
     set_inductor_config: bool = True
     preserve_zero: Optional[bool] = None
     # only used in version >= 2
-    packing_format: PackingFormat = PackingFormat.PLAIN
+    int4_packing_format: Int4PackingFormat = Int4PackingFormat.PLAIN
     int4_choose_qparams_algorithm: Int4ChooseQParamsAlgorithm = (
         Int4ChooseQParamsAlgorithm.TINYGEMM
     )
@@ -1113,7 +1114,7 @@ def _int4_weight_only_quantize_tensor(weight, config):
     use_hqq = config.use_hqq
     int4_choose_qparams_algorithm = config.int4_choose_qparams_algorithm
     zero_point_domain = config.zero_point_domain
-    packing_format = config.packing_format
+    int4_packing_format = config.int4_packing_format

     if weight.shape[-1] % group_size != 0:
         logger.info(
@@ -1127,50 +1128,50 @@ def _int4_weight_only_quantize_tensor(weight, config):
         block_size = list(block_size)

     if int4_choose_qparams_algorithm == Int4ChooseQParamsAlgorithm.HQQ:
-        assert packing_format == PackingFormat.TILE_PACKED_TO_4D, (
-            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {packing_format}, it's only supported by PackingFormat.TILE_PACKED_TO_4D curretnly"
+        assert int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D, (
+            f"Int4ChooseQParamsAlgorithm.HQQ is not supported by packing format {int4_packing_format}, it's only supported by Int4PackingFormat.TILE_PACKED_TO_4D curretnly"
         )

-    if packing_format == PackingFormat.PRESHUFFLED:
+    if int4_packing_format == Int4PackingFormat.PRESHUFFLED:
         new_weight = Int4PreshuffledTensor.from_hp(
             weight,
             block_size,
             activation_dtype=torch.bfloat16,
         )
         return new_weight
-    elif packing_format == PackingFormat.PLAIN:
+    elif int4_packing_format == Int4PackingFormat.PLAIN:
         new_weight = Int4Tensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.PLAIN_INT32:
+    elif int4_packing_format == Int4PackingFormat.PLAIN_INT32:
         new_weight = Int4PlainInt32Tensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.MARLIN_SPARSE:
+    elif int4_packing_format == Int4PackingFormat.MARLIN_SPARSE:
         new_weight = Int4MarlinSparseTensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.OPAQUE:
+    elif int4_packing_format == Int4PackingFormat.OPAQUE:
         new_weight = Int4OpaqueTensor.from_hp(
             weight,
             block_size,
         )
         return new_weight
-    elif packing_format == PackingFormat.TILE_PACKED_TO_4D:
+    elif int4_packing_format == Int4PackingFormat.TILE_PACKED_TO_4D:
         new_weight = Int4TilePackedTo4dTensor.from_hp(
             weight,
             block_size,
             int4_choose_qparams_algorithm=int4_choose_qparams_algorithm,
         )
         return new_weight
     else:
-        raise ValueError(f"Unsupported packing format: {packing_format}")
+        raise ValueError(f"Unsupported int4 packing format: {int4_packing_format}")

     assert config.version == 1

@@ -1254,10 +1255,10 @@ class Float8DynamicActivationInt4WeightConfig(AOBaseConfig):
     and above and no benefits of making it bigger)

     Args:
-        `packing_format`: how the weight is packed, only preshuffled is supported
+        `int4_packing_format`: how the weight is packed, only preshuffled is supported
     """

-    packing_format: PackingFormat = "preshuffled"
+    int4_packing_format: Int4PackingFormat = "preshuffled"


 @register_quantize_module_handler(Float8DynamicActivationInt4WeightConfig)
@@ -1268,10 +1269,10 @@ def _float8_dynamic_activation_int4_weight_transform(
         "applying int8 weight only quant requires module to have weight attribute"
         + " but {module} does not have one"
     )
-    packing_format = config.packing_format
+    int4_packing_format = config.int4_packing_format

-    assert packing_format == "preshuffled", (
-        f"only preshuffled packing_format supported right now, got: {packing_format}"
+    assert int4_packing_format == "preshuffled", (
+        f"only preshuffled int4_packing_format supported right now, got: {int4_packing_format}"
     )
     weight = module.weight
     group_size = 128
torchao/quantization/quantize_/common/packing_format.py
Lines changed: 1 addition & 22 deletions

@@ -16,7 +16,7 @@ class PackingFormat(str, Enum):

     """
     plain means the format that quantized Tensor data lays out elements in Tensor sequentially,
-    for example: for a Tensor of shape (4, 6):
+    for example: for a Tensor of shape (4, 6):
     a_0_0, a_0_1, ..., a_0_5,
     ...
     a_3_0, a_3_1, ..., a_3_5
@@ -26,32 +26,11 @@ class PackingFormat(str, Enum):
     """
     PLAIN = "plain"

-    """
-    preshuffled is referring to the preshuffled format used by fbgemm kernels
-    """
-    PRESHUFFLED = "preshuffled"
-
-    """
-    marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
-    """
-    MARLIN_SPARSE = "marlin_sparse"
-
     """
     Unpacked to int8 means the subbyte quantized data is stored as int8
     """
     UNPACKED_TO_INT8 = "unpacked_to_int8"

-    """
-    plain_int32 is referring to the format used by int4 weight-only quantization.
-    which is a groupwise quantization format 2*int4 is store in a byte and 4*(int4*2) is stored in a int32.
-    """
-    PLAIN_INT32 = "plain_int32"
-
-    """
-    tile_packed_to_4d is referring to the format used by tinygemm kernels for int4 quantization
-    """
-    TILE_PACKED_TO_4D = "tile_packed_to_4d"
-
     """
     Opaque packing format that's used for tensors that does not have a predefined packing format
     (that may be decided on hardware, tensor shape, library availability etc.) and it's not
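The new torchao/quantization/quantize_/workflows/int4/int4_packing_format.py itself is not shown in this view; here is a plausible sketch of the enum it defines, inferred from the members referenced in quant_api.py above (actual docstrings and ordering may differ):

from enum import Enum

class Int4PackingFormat(str, Enum):
    # int4-specific packing formats, split out of the shared PackingFormat
    PLAIN = "plain"
    PRESHUFFLED = "preshuffled"
    MARLIN_SPARSE = "marlin_sparse"
    PLAIN_INT32 = "plain_int32"
    TILE_PACKED_TO_4D = "tile_packed_to_4d"
    OPAQUE = "opaque"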

torchao/quantization/quantize_/workflows/__init__.py
Lines changed: 2 additions & 0 deletions

@@ -9,6 +9,7 @@
 from .int4.int4_opaque_tensor import (
     Int4OpaqueTensor,
 )
+from .int4.int4_packing_format import Int4PackingFormat
 from .int4.int4_plain_int32_tensor import (
     Int4PlainInt32Tensor,
 )
@@ -39,4 +40,5 @@
     "IntxUnpackedTensor",
     "IntxUnpackedToInt8Tensor",
     "Int4ChooseQParamsAlgorithm",
+    "Int4PackingFormat",
 ]
