Commit 0ad98af

Bump int4 weight only config version to 2
Summary:
Int4WeightOnlyConfig currently supports versions 1 and 2, with version 1 as the default. This PR changes the default to version 2 and updates the affected call sites. Call sites that still rely on the old configuration now pass an explicit `version=1`; migrating those call sites to version 2 can be done separately. READMEs are migrated to the version 2 usage directly.

Deprecation: TODO

Test Plan:
Regression tests:
python test/dtypes/test_affine_quantized.py
python test/quantization/test_quant_api.py
python test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:
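What the default bump means for callers, as a minimal sketch (assuming a CUDA + bfloat16 setup like the tests in this PR; the keyword combinations are taken from the diffs below, and the single-linear model is a hypothetical example, not code from this repo):

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# hypothetical toy model for illustration
model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)

# Before this PR, Int4WeightOnlyConfig defaulted to version=1.
# Call sites that still need the old flow now pin it explicitly:
legacy_config = Int4WeightOnlyConfig(group_size=128, version=1)

# After this PR, plain construction picks up version=2, e.g. the
# HQQ + tile_packed_to_4d combination used in quantize_and_upload.py:
new_config = Int4WeightOnlyConfig(
    group_size=128,
    packing_format="tile_packed_to_4d",
    int4_choose_qparams_algorithm="hqq",
)

quantize_(model, new_config)  # quantizes the linear weight in place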
1 parent 4872c4f commit 0ad98af

24 files changed: +174 −74 lines

.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 3 additions & 4 deletions
@@ -206,7 +206,7 @@ def _untie_weights_and_save_locally(model_id):
 
 _int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig
-quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2)
+quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")
 quantization_config = TorchAoConfig(quant_type=quant_config)
 quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -256,7 +256,7 @@ def _untie_weights_and_save_locally(model_id):
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+base_config = Int4WeightOnlyConfig(group_size=128)
 quant_config = AWQConfig(base_config, step="prepare")
 quantize_(
     model,
@@ -635,7 +635,6 @@ def quantize_and_upload(
             group_size=128,
             packing_format="tile_packed_to_4d",
             int4_choose_qparams_algorithm="hqq",
-            version=2,
         ),
         "INT8-INT4": ModuleFqnToConfig(
             {
@@ -669,7 +668,7 @@ def quantize_and_upload(
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+base_config = Int4WeightOnlyConfig(group_size=128)
 quant_config = AWQConfig(base_config, step="prepare")
 quantize_(
     model,

benchmarks/microbenchmarks/test/test_benchmark_inference.py

Lines changed: 1 addition & 1 deletion
@@ -58,7 +58,7 @@ def test_run_inference_with_semi_sparse_marlin(self, mock_string_to_config):
 
         # Test with semi-sparse config
         mock_string_to_config.return_value = Int4WeightOnlyConfig(
-            layout=MarlinSparseLayout()
+            layout=MarlinSparseLayout(), version=1
         )
         config = BenchmarkConfig(
             quantization="marlin",

benchmarks/microbenchmarks/utils.py

Lines changed: 1 addition & 1 deletion
@@ -206,7 +206,7 @@ def string_to_config(
             128,
             256,
         ], f"int4wo group_size needs to be one of [32,64,128,256] but got {group_size}"
-        return Int4WeightOnlyConfig(group_size=group_size, use_hqq=use_hqq)
+        return Int4WeightOnlyConfig(group_size=group_size, use_hqq=True, version=1)
     elif "int8adq-int4w-symm" in quantization:
         from torchao.dtypes import CutlassInt4PackedLayout

docs/source/torchao_vllm_integration.md

Lines changed: 2 additions & 1 deletion
@@ -45,6 +45,7 @@ from torchao.quantization import Int4WeightOnlyConfig
 config = Int4WeightOnlyConfig(
     group_size=128,
     use_hqq=True,
+    version=1,
 )
 assert isinstance(config, AOBaseConfig)
 ```
@@ -81,7 +82,7 @@ from torchao.quantization import Int4WeightOnlyConfig
 
 # Create quantization configuration
 quantization_config = TorchAoConfig(
-    quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True)
+    quant_type=Int4WeightOnlyConfig(group_size=128, use_hqq=True, version=1)
 )
 
 # Load and automatically quantize the model
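Note that the vLLM integration doc above keeps the legacy `use_hqq=True` spelling and therefore now pins `version=1`. For comparison, a hedged sketch of the version-2 style used elsewhere in this commit, where HQQ is requested via `int4_choose_qparams_algorithm` instead of `use_hqq` (shown for illustration only, not necessarily a drop-in replacement for the vLLM flow):

from torchao.quantization import Int4WeightOnlyConfig

# version 2 (now the default): HQQ qparams selection is named explicitly,
# and no layout/use_hqq arguments are passed
config = Int4WeightOnlyConfig(
    group_size=128,
    int4_choose_qparams_algorithm="hqq",
)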

test/dtypes/test_affine_quantized.py

Lines changed: 10 additions & 10 deletions
@@ -32,7 +32,6 @@
     Int8DynamicActivationInt8WeightConfig,
     float8_weight_only,
     int4_dynamic_activation_int4_weight,
-    int4_weight_only,
     int8_dynamic_activation_int4_weight,
     int8_dynamic_activation_int8_weight,
     int8_weight_only,
@@ -66,22 +65,23 @@ def get_quantization_functions(
     if do_int4:
         if check_cpu_version(device):
             base_functions.append(
-                int4_weight_only(group_size=32, layout=Int4CPULayout())
+                Int4WeightOnlyConfig(group_size=32, layout=Int4CPULayout(), version=1)
             )
         elif check_xpu_version(device):
            base_functions.append(
-                int4_weight_only(group_size=32, layout=Int4XPULayout())
+                Int4WeightOnlyConfig(group_size=32, layout=Int4XPULayout(), version=1)
             )
             if int4_zp_int:
                 base_functions.append(
-                    int4_weight_only(
+                    Int4WeightOnlyConfig(
                         group_size=32,
                         layout=Int4XPULayout(),
                         zero_point_domain=ZeroPointDomain.INT,
+                        version=1,
                     )
                 )
         else:
-            base_functions.append(int4_weight_only(group_size=32))
+            base_functions.append(Int4WeightOnlyConfig(group_size=32, version=1))
     if device == "cuda" and not is_ROCM():
         base_functions.append(
             int8_dynamic_activation_int4_weight(
@@ -118,7 +118,7 @@ def test_tensor_core_layout_transpose(self):
         linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         t = linear.weight
         shape = t.shape
-        apply_int4_weight_only_quant = int4_weight_only(group_size=32)
+        apply_int4_weight_only_quant = Int4WeightOnlyConfig(group_size=32, version=1)
         quantize_(linear, apply_int4_weight_only_quant)
         ql = linear
         aqt = ql.weight
@@ -353,7 +353,7 @@ def test_slice_int4wo(self, device, dtype):
         # out_feature not divisible by 8
         # to test slice + padding for int4 weight only quantization
         dummy = nn.Linear(256, 321, dtype=dtype, device=device)
-        quantize_(dummy, Int4WeightOnlyConfig())
+        quantize_(dummy, Int4WeightOnlyConfig(version=1))
         # make sure these run without error
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)
@@ -467,7 +467,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, Int4WeightOnlyConfig())
+        quantize_(l, Int4WeightOnlyConfig(version=1))
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
@@ -483,7 +483,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
 
         # dummy_l has random input (shouldn't be 0)
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, Int4WeightOnlyConfig())
+        quantize_(dummy_l, Int4WeightOnlyConfig(version=1))
         quantized = dummy_l.weight
         quantized = quantized.narrow(0, 0, 512)
 
@@ -502,7 +502,7 @@ def test_mm_int4wo(self, device, dtype):
 
         l = torch.nn.Linear(512, 1024).to(device).to(dtype)
         l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Int4WeightOnlyConfig())
+        quantize_(l, Int4WeightOnlyConfig(version=1))
         # weight shape: 1024 x 512
         weight = l.weight

test/integration/test_integration.py

Lines changed: 10 additions & 4 deletions
@@ -135,17 +135,23 @@ def _int4wo_api(mod, use_hqq=False):
         quantize_(
             mod,
             int4_weight_only(
-                layout=Int4CPULayout(), use_hqq=use_hqq, set_inductor_config=False
+                layout=Int4CPULayout(),
+                use_hqq=use_hqq,
+                set_inductor_config=False,
+                version=1,
             ),
         )
         unwrap_tensor_subclass(mod)
     elif check_xpu_version(next(mod.parameters()).device):
         quantize_(
-            mod, int4_weight_only(layout=Int4XPULayout()), set_inductor_config=False
+            mod,
+            int4_weight_only(layout=Int4XPULayout()),
+            set_inductor_config=False,
+            version=1,
         )
         unwrap_tensor_subclass(mod)
     else:
-        quantize_(mod, int4_weight_only(set_inductor_config=False))
+        quantize_(mod, int4_weight_only(set_inductor_config=False, version=1))
 
 
 def _int8da_int4w_api(mod):
@@ -1077,7 +1083,7 @@ def test_int4_weight_only_quant_subclass_api_grouped(self, device, dtype):
         ):
             for groupsize in [64, 32]:
                 for layout in layout_list:
-                    kwargs = {"groupsize": groupsize, "layout": layout}
+                    kwargs = {"groupsize": groupsize, "layout": layout, "version": 1}
 
                     def api(mod):
                         kwargs_copy = kwargs.copy()

test/integration/test_load_and_run_checkpoint.py

Lines changed: 93 additions & 18 deletions
@@ -24,9 +24,22 @@
 
 # please check model card for how to generate these models
 
-_DEPRECATED_SINGLE_LINEAR_MODEL_NAMES = [
+# high precision model, used for testing config deprecation warning
+_HIGH_PRECISION_MODEL = "facebook/opt-125m"
+
+_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev"
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev",
+        1,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
+    # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 _DEPRECATED_MODEL_INFO = [
@@ -36,15 +49,33 @@
         1,
         "Float8DynamicActivationFloat8WeightConfig",
     ),
+    # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
-_SINGLE_LINEAR_MODEL_NAMES = [
+_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+        2,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 
@@ -55,7 +86,9 @@
     "Skipping the test in fbcode for now, not sure how to download from transformers",
 )
 class TestLoadAndRunCheckpoint(TestCase):
-    def _test_single_linear_helper(self, model_name):
+    def _test_single_linear_helper(
+        self, model_name, version, config_name, is_deprecated
+    ):
         from huggingface_hub import hf_hub_download
 
         downloaded_model = hf_hub_download(model_name, filename="model.pt")
@@ -69,8 +102,20 @@ def _test_single_linear_helper(self, model_name):
         model = torch.nn.Sequential(
             torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
         )
-        with open(downloaded_model, "rb") as f:
+
+        with (
+            open(downloaded_model, "rb") as f,
+            warnings.catch_warnings(record=True) as caught_warnings,
+        ):
             model.load_state_dict(torch.load(f), assign=True)
+        if is_deprecated:
+            assert any(
+                f"Models quantized with version {version} of {config_name} is deprecated"
+                in str(w.message)
+                for w in caught_warnings
+            ), (
+                f"Didn't get expected warning message for deprecation for model: {model_name}"
+            )
 
         downloaded_example_inputs = hf_hub_download(
             model_name, filename="model_inputs.pt"
@@ -84,17 +129,23 @@ def _test_single_linear_helper(self, model_name):
         output = model(*example_inputs)
         self.assertTrue(torch.equal(output, ref_output))
 
-    @common_utils.parametrize("model_name", _DEPRECATED_SINGLE_LINEAR_MODEL_NAMES)
-    def test_deprecated_single_linear(self, model_name):
-        self._test_single_linear_helper(model_name)
+    @common_utils.parametrize("model_info", _DEPRECATED_SINGLE_LINEAR_MODEL_INFO)
+    def test_deprecated_single_linear(self, model_info):
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=True
+        )
 
-    @common_utils.parametrize("model_name", _SINGLE_LINEAR_MODEL_NAMES)
-    def test_single_linear(self, model_name):
+    @common_utils.parametrize("model_info", _SINGLE_LINEAR_MODEL_INFO)
+    def test_single_linear(self, model_info):
         """Test that we can load and run the quantized linear checkpoint with saved sample input
         and match the saved output, to make sure there is no BC breaking changes
         when we make changes to tensor subclass implementations
         """
-        self._test_single_linear_helper(model_name)
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=False
+        )
 
     @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
     def test_deprecated_hf_models(self, model_info):
@@ -109,17 +160,23 @@ def test_deprecated_hf_models(self, model_info):
             torch_dtype="bfloat16",
             device_map="cuda:0",
         )
+        # version mismatch check in config.py
        assert any(
             "Stored version is not the same as current default version of the config"
             in str(w.message)
             for w in caught_warnings
-        ), "Didn't get expected warning message for version mismatch"
+        ), (
+            f"Didn't get expected warning message for version mismatch for config {config_name}, model {model_name}"
+        )
 
+        # checkpoint deprecation
         assert any(
-            f"Models quantized with version 1 of {config_name} is deprecated"
+            f"Models quantized with version {version} of {config_name} is deprecated"
             in str(w.message)
             for w in caught_warnings
-        ), "Didn't get expected warning message for deprecation"
+        ), (
+            f"Didn't get expected warning message for deprecation for model {model_name}"
+        )
         assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
         assert (
             quantized_model.config.quantization_config.quant_type.version == version
@@ -139,7 +196,8 @@ def test_deprecated_hf_models(self, model_info):
             return_tensors="pt",
         ).to("cuda")
         generated_ids = quantized_model.generate(
-            **inputs, max_new_tokens=128, temperature=0
+            **inputs,
+            max_new_tokens=128,
         )
 
         downloaded_output = hf_hub_download(model_name, filename="model_output.pt")
@@ -153,6 +211,23 @@ def test_deprecated_hf_models(self, model_info):
             generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
 
+        # make sure we throw warning for config deprecation
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            _ = AutoModelForCausalLM.from_pretrained(
+                _HIGH_PRECISION_MODEL,
+                torch_dtype="bfloat16",
+                device_map="cuda:0",
+                quantization_config=quantized_model.config.quantization_config,
+            )
+        # config version deprecation in quant_api.py
+        assert any(
+            f"Config Deprecation: version {version} of {config_name} is deprecated and will no longer be supported in a future release"
+            in str(w.message)
+            for w in caught_warnings
+        ), (
+            f"Didn't get expected warning message for version deprecation for config {config_name}, model {model_name}"
+        )
+
 
 common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
