Skip to content

Commit e73434e

Browse files
committed
Add Int4TensorCoreTilePackedTensor for tensor core tiled int4 quantization
This commit introduces Int4TensorCoreTilePackedTensor, a new tensor subclass for int4 weight-only quantization using tensor core tiled packing format. Key features: - Implements tensor core tiled packing for efficient computation on tensor cores - Uses tinygemm quantization path instead of HQQ for consistency - Supports PackingFormat.TENSOR_CORE_TILE_PACKED in Int4WeightOnlyConfig version 2 - Optimized for tinygemm int4mm kernel (_weight_int4pack_mm) - Includes comprehensive test suite The implementation follows the same pattern as other int4 tensor subclasses but uses a specialized packing format optimized for tensor core matrix multiplication performance. Changes: - Add Int4TensorCoreTilePackedTensor implementation - Update Int4WeightOnlyConfig version 2 to support TENSOR_CORE_TILE_PACKED packing format - Add TENSOR_CORE_TILE_PACKED to PackingFormat enum - Replace HQQ quantization with _quantize_affine_tinygemm for consistency - Add comprehensive tests including serialization, different group sizes, and error conditions - Update __init__.py files to export new tensor class
1 parent 751d7f6 commit e73434e

File tree

7 files changed

+581
-55
lines changed

7 files changed

+581
-55
lines changed
Lines changed: 170 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,170 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import tempfile
8+
import unittest
9+
10+
import torch
11+
from torch.testing._internal.common_utils import (
12+
TestCase,
13+
instantiate_parametrized_tests,
14+
parametrize,
15+
run_tests,
16+
)
17+
18+
from torchao.quantization import Int4WeightOnlyConfig, quantize_
19+
from torchao.quantization.quantize_.common.packing_format import PackingFormat
20+
from torchao.quantization.quantize_.workflows.int4.int4_tensor_core_tile_packed_tensor import (
21+
Int4TensorCoreTilePackedTensor,
22+
)
23+
from torchao.quantization.utils import compute_error
24+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_4
25+
26+
TENSOR_CORE_TILED_CONFIG = Int4WeightOnlyConfig(
27+
group_size=128,
28+
packing_format=PackingFormat.TENSOR_CORE_TILE_PACKED,
29+
version=2,
30+
)
31+
32+
33+
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_4, "Need pytorch 2.4+")
34+
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
35+
class TestInt4TensorCoreTilePackedTensor(TestCase):
36+
def setUp(self):
37+
self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
38+
39+
@parametrize("config", [TENSOR_CORE_TILED_CONFIG])
40+
@parametrize(
41+
"sizes",
42+
[
43+
((128,), 256, 128),
44+
((32, 128), 512, 128),
45+
((2, 32, 128), 256, 128),
46+
],
47+
)
48+
def test_linear(self, config, sizes):
49+
dtype = torch.bfloat16
50+
device = "cuda"
51+
52+
M, N, K = sizes
53+
input = torch.randn(*M, K, dtype=dtype, device=device)
54+
linear = torch.nn.Linear(K, N, dtype=dtype, device=device)
55+
56+
original = linear(input)
57+
quantize_(linear, config)
58+
quantized = linear(input)
59+
self.assertTrue(compute_error(original, quantized) > 1)
60+
61+
compiled_linear = torch.compile(linear)
62+
quantized_and_compiled = compiled_linear(input)
63+
self.assertTrue(compute_error(original, quantized_and_compiled) > 1)
64+
65+
def test_from_hp(self):
66+
"""Test creating Int4TensorCoreTilePackedTensor from high precision tensor"""
67+
dtype = torch.bfloat16
68+
device = "cuda"
69+
hp_tensor = torch.randn(256, 128, dtype=dtype, device=device)
70+
block_size = (1, 64)
71+
72+
tensor = Int4TensorCoreTilePackedTensor.from_hp(hp_tensor, block_size)
73+
74+
self.assertEqual(tensor.shape, hp_tensor.shape)
75+
self.assertEqual(tensor.block_size, block_size)
76+
self.assertEqual(tensor.device.type, device)
77+
self.assertEqual(tensor.dtype, dtype)
78+
79+
@parametrize("config", [TENSOR_CORE_TILED_CONFIG])
80+
def test_to_device(self, config):
81+
for device in self.GPU_DEVICES:
82+
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
83+
quantize_(linear.cuda(), config)
84+
linear.to(device)
85+
86+
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
87+
quantize_(linear.cuda(), config)
88+
linear.to(device=device)
89+
90+
@parametrize("config", [TENSOR_CORE_TILED_CONFIG])
91+
def test_module_path(self, config):
92+
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
93+
quantize_(linear.cuda(), config)
94+
self.assertEqual(
95+
str(type(linear.weight)),
96+
"<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>",
97+
)
98+
99+
def test_serialization(self):
100+
"""Test saving and loading the tensor directly and via state_dict"""
101+
dtype = torch.bfloat16
102+
device = "cuda"
103+
hp_tensor = torch.randn(128, 256, dtype=dtype, device=device)
104+
block_size = (1, 64)
105+
106+
tensor = Int4TensorCoreTilePackedTensor.from_hp(hp_tensor, block_size)
107+
108+
# Test direct tensor serialization
109+
with tempfile.NamedTemporaryFile() as f:
110+
torch.save(tensor, f)
111+
f.seek(0)
112+
loaded_tensor = torch.load(f)
113+
114+
self.assertEqual(loaded_tensor.shape, tensor.shape)
115+
self.assertEqual(loaded_tensor.block_size, tensor.block_size)
116+
self.assertEqual(
117+
str(type(loaded_tensor)),
118+
"<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>",
119+
)
120+
121+
# Test state_dict serialization
122+
linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16)
123+
quantize_(linear.cuda(), TENSOR_CORE_TILED_CONFIG)
124+
125+
with tempfile.NamedTemporaryFile() as f:
126+
torch.save(linear.state_dict(), f)
127+
f.seek(0)
128+
state_dict = torch.load(f)
129+
self.assertEqual(
130+
str(type(state_dict["weight"])),
131+
"<class 'torchao.quantization.Int4TensorCoreTilePackedTensor'>",
132+
)
133+
134+
@parametrize("group_size", [32, 64, 128])
135+
def test_different_group_sizes(self, group_size):
136+
"""Test with different group sizes"""
137+
dtype = torch.bfloat16
138+
device = "cuda"
139+
hp_tensor = torch.randn(256, 512, dtype=dtype, device=device)
140+
block_size = (1, group_size)
141+
142+
tensor = Int4TensorCoreTilePackedTensor.from_hp(hp_tensor, block_size)
143+
144+
self.assertEqual(tensor.shape, hp_tensor.shape)
145+
self.assertEqual(tensor.block_size, block_size)
146+
147+
def test_error_conditions(self):
148+
"""Test various error conditions"""
149+
dtype = torch.bfloat16
150+
device = "cuda"
151+
hp_tensor = torch.randn(128, 256, dtype=dtype, device=device)
152+
153+
# Test invalid block_size length
154+
with self.assertRaises(AssertionError):
155+
Int4TensorCoreTilePackedTensor.from_hp(
156+
hp_tensor, (64,)
157+
) # block_size length mismatch
158+
159+
# Test non-groupwise quantization
160+
with self.assertRaises(AssertionError):
161+
Int4TensorCoreTilePackedTensor.from_hp(
162+
hp_tensor, (2, 64)
163+
) # first element should be 1
164+
165+
166+
instantiate_parametrized_tests(TestInt4TensorCoreTilePackedTensor)
167+
168+
169+
if __name__ == "__main__":
170+
run_tests()

torchao/quantization/__init__.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,4 @@
1-
from torchao.kernel import (
2-
int_scaled_matmul,
3-
safe_int_mm,
4-
)
1+
from torchao.kernel import int_scaled_matmul, safe_int_mm
52

63
from .autoquant import (
74
ALL_AUTOQUANT_CLASS_LIST,
@@ -13,18 +10,8 @@
1310
OTHER_AUTOQUANT_CLASS_LIST,
1411
autoquant,
1512
)
16-
from .GPTQ import (
17-
Int4WeightOnlyGPTQQuantizer,
18-
MultiTensor,
19-
MultiTensorInputRecorder,
20-
)
21-
from .granularity import (
22-
PerAxis,
23-
PerGroup,
24-
PerRow,
25-
PerTensor,
26-
PerToken,
27-
)
13+
from .GPTQ import Int4WeightOnlyGPTQQuantizer, MultiTensor, MultiTensorInputRecorder
14+
from .granularity import PerAxis, PerGroup, PerRow, PerTensor, PerToken
2815
from .linear_activation_quantized_tensor import (
2916
LinearActivationQuantizedTensor,
3017
to_linear_activation_quantized,
@@ -37,10 +24,7 @@
3724
Int8DynActInt4WeightLinear,
3825
Int8DynActInt4WeightQuantizer,
3926
)
40-
from .observer import (
41-
AffineQuantizedMinMaxObserver,
42-
AffineQuantizedObserverBase,
43-
)
27+
from .observer import AffineQuantizedMinMaxObserver, AffineQuantizedObserverBase
4428
from .quant_api import (
4529
CutlassInt4PackedLayout,
4630
FbgemmConfig,
@@ -93,6 +77,7 @@
9377
Int4MarlinSparseTensor,
9478
Int4PreshuffledTensor,
9579
Int4Tensor,
80+
Int4TensorCoreTilePackedTensor,
9681
)
9782
from .smoothquant import (
9883
SmoothFakeDynamicallyQuantizedLinear,
@@ -105,9 +90,7 @@
10590
from .subclass import * # noqa: F403
10691
from .transform_module import register_quantize_module_handler
10792
from .unified import Quantizer, TwoStepQuantizer
108-
from .utils import (
109-
compute_error,
110-
)
93+
from .utils import compute_error
11194
from .weight_only import WeightOnlyInt8QuantLinear
11295

11396
# TODO: remove after migration of APIs are done
@@ -161,6 +144,7 @@
161144
"Int4Tensor",
162145
"Int4PreshuffledTensor",
163146
"Int4MarlinSparseTensor",
147+
"Int4TensorCoreTilePackedTensor",
164148
"Float8Tensor",
165149
# smooth quant - subject to change
166150
"get_scale",

torchao/quantization/quant_api.py

Lines changed: 22 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -66,15 +66,13 @@
6666
LinearActivationWeightObservedTensor,
6767
)
6868
from torchao.quantization.observer import AffineQuantizedObserverBase, get_block_size
69-
from torchao.quantization.quantize_.common import (
70-
KernelPreference,
71-
PackingFormat,
72-
)
69+
from torchao.quantization.quantize_.common import KernelPreference, PackingFormat
7370
from torchao.quantization.quantize_.workflows import (
7471
Float8Tensor,
7572
Int4MarlinSparseTensor,
7673
Int4PreshuffledTensor,
7774
Int4Tensor,
75+
Int4TensorCoreTilePackedTensor,
7876
QuantizeTensorToFloat8Kwargs,
7977
)
8078
from torchao.quantization.transform_module import (
@@ -92,35 +90,16 @@
9290
)
9391

9492
from .autoquant import AutoQuantizableLinearWeight, autoquant
95-
from .GPTQ import (
96-
Int4WeightOnlyGPTQQuantizer,
97-
)
98-
from .granularity import (
99-
Granularity,
100-
PerAxis,
101-
PerGroup,
102-
PerRow,
103-
PerTensor,
104-
)
93+
from .GPTQ import Int4WeightOnlyGPTQQuantizer
94+
from .granularity import Granularity, PerAxis, PerGroup, PerRow, PerTensor
10595
from .linear_activation_quantized_tensor import (
10696
LinearActivationQuantizedTensor,
10797
to_linear_activation_quantized,
10898
)
109-
from .linear_quant_modules import (
110-
Int4WeightOnlyQuantizer,
111-
Int8DynActInt4WeightQuantizer,
112-
)
113-
from .qat import (
114-
intx_quantization_aware_training,
115-
)
116-
from .quant_primitives import (
117-
_DTYPE_TO_QVALUE_BOUNDS,
118-
MappingType,
119-
ZeroPointDomain,
120-
)
121-
from .subclass import (
122-
QuantizedLinearWeightBase,
123-
)
99+
from .linear_quant_modules import Int4WeightOnlyQuantizer, Int8DynActInt4WeightQuantizer
100+
from .qat import intx_quantization_aware_training
101+
from .quant_primitives import _DTYPE_TO_QVALUE_BOUNDS, MappingType, ZeroPointDomain
102+
from .subclass import QuantizedLinearWeightBase
124103
from .unified import Quantizer, TwoStepQuantizer
125104
from .utils import _get_per_token_block_size
126105

@@ -1075,6 +1054,12 @@ def _int4_weight_only_quantize_tensor(weight, config):
10751054
block_size,
10761055
)
10771056
return new_weight
1057+
elif packing_format == PackingFormat.TENSOR_CORE_TILE_PACKED:
1058+
new_weight = Int4TensorCoreTilePackedTensor.from_hp(
1059+
weight,
1060+
block_size,
1061+
)
1062+
return new_weight
10781063
else:
10791064
raise ValueError(f"Unsupported packing format: {packing_format}")
10801065

@@ -1449,10 +1434,12 @@ def int8_dynamic_activation_int8_semi_sparse_weight():
14491434
Applies int8 dnynamic symmetric per-token activation and int8 per-channel weight
14501435
quantization + 2:4 sparsity to linear layers.
14511436
"""
1452-
warnings.warn("""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
1437+
warnings.warn(
1438+
"""int8_dyanmic_activation_int8_semi_sparse_weight() will be deprecated at a later release. Please use the layout kwarg in int8_dynamic_activation_int8_weight instead.
14531439
14541440
from torchao.dtypes import SemiSparseLayout
1455-
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()""")
1441+
int8_dynamic_activation_int8_weight(layout=SemiSparseLayout()"""
1442+
)
14561443

14571444
return int8_dynamic_activation_int8_weight(layout=SemiSparseLayout())
14581445

@@ -2000,7 +1987,10 @@ def __post_init__(self):
20001987
assert self.granularity.axis == 0, (
20011988
f"axis must be 0 with PerAxis, but got {self.granularity.axis}"
20021989
)
2003-
assert self.mapping_type in [MappingType.ASYMMETRIC, MappingType.SYMMETRIC], (
1990+
assert self.mapping_type in [
1991+
MappingType.ASYMMETRIC,
1992+
MappingType.SYMMETRIC,
1993+
], (
20041994
f"mapping_type must be MappingType.ASYMMETRIC or MappingType.SYMMETRIC, but got {self.mapping_type}"
20051995
)
20061996

torchao/quantization/quantize_/common/packing_format.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,3 +35,8 @@ class PackingFormat(str, Enum):
3535
marlin_sparse is referring to the format used by marlin kernels, only supports symmetric quantization
3636
"""
3737
MARLIN_SPARSE = "marlin_sparse"
38+
39+
"""
40+
tensor_core_tile_packed is referring to the format used by tensor core tiled kernels for int4 quantization
41+
"""
42+
TENSOR_CORE_TILE_PACKED = "tensor_core_tile_packed"

torchao/quantization/quantize_/workflows/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,13 @@
1111
from .int4.int4_tensor import (
1212
Int4Tensor,
1313
)
14+
from .int4.int4_tensor_core_tile_packed_tensor import Int4TensorCoreTilePackedTensor
1415

1516
__all__ = [
1617
"Int4Tensor",
1718
"Int4PreshuffledTensor",
1819
"Int4MarlinSparseTensor",
20+
"Int4TensorCoreTilePackedTensor",
1921
"Float8Tensor",
2022
"QuantizeTensorToFloat8Kwargs",
2123
]
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,11 @@
1+
from .int4_marlin_sparse_tensor import Int4MarlinSparseTensor
12
from .int4_preshuffled_tensor import Int4PreshuffledTensor
23
from .int4_tensor import Int4Tensor
4+
from .int4_tensor_core_tile_packed_tensor import Int4TensorCoreTilePackedTensor
35

46
__all__ = [
57
"Int4PreshuffledTensor",
68
"Int4Tensor",
9+
"Int4MarlinSparseTensor",
10+
"Int4TensorCoreTilePackedTensor",
711
]

0 commit comments

Comments
 (0)