Commit 5301a7e

Bump int4 weight only config version to 2
Summary: Int4WeightOnlyConfig currently supports version 1 and version 2, with version 1 as the default. This PR changes the default to version 2 and updates the call sites accordingly. Call sites that still rely on the old configuration now pass an explicit `version=1`; migrating those call sites to version 2 can be done separately. READMEs are migrated to version 2 directly.

Deprecation: TODO

Test Plan:
Regression tests:
python test/dtypes/test_affine_quantized.py
python test/quantization/test_quant_api.py
python test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tensor.py
python test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py

Reviewers:

Subscribers:

Tasks:

Tags:
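To illustrate the call-site change, here is a minimal sketch (assuming the torchao `quantize_` API and a CUDA device; the toy module is illustrative and not taken from this repo): after this commit, `Int4WeightOnlyConfig` defaults to `version=2`, so new call sites drop the explicit `version=2` argument, while call sites that still need the old behavior pin `version=1`.

import torch
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Toy module for illustration only.
model = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda"))

# After this commit the config defaults to version=2, so no version argument is needed.
quantize_(model, Int4WeightOnlyConfig(group_size=128))

# Call sites that still depend on the old behavior pin the previous version explicitly.
legacy = torch.nn.Sequential(torch.nn.Linear(256, 512, dtype=torch.bfloat16, device="cuda"))
quantize_(legacy, Int4WeightOnlyConfig(group_size=128, version=1))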
1 parent c452495 commit 5301a7e

File tree

18 files changed, +133 -49 lines

.github/scripts/torchao_model_releases/quantize_and_upload.py

Lines changed: 3 additions & 4 deletions
@@ -206,7 +206,7 @@ def _untie_weights_and_save_locally(model_id):
 
 _int4_quant_code = """
 from torchao.quantization import Int4WeightOnlyConfig
-quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq", version=2)
+quant_config = Int4WeightOnlyConfig(group_size=128, packing_format="tile_packed_to_4d", int4_choose_qparams_algorithm="hqq")
 quantization_config = TorchAoConfig(quant_type=quant_config)
 quantized_model = AutoModelForCausalLM.from_pretrained(model_to_quantize, device_map="auto", torch_dtype=torch.bfloat16, quantization_config=quantization_config)
 tokenizer = AutoTokenizer.from_pretrained(model_id)
@@ -256,7 +256,7 @@ def _untie_weights_and_save_locally(model_id):
 )
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+base_config = Int4WeightOnlyConfig(group_size=128)
 quant_config = AWQConfig(base_config, step="prepare")
 quantize_(
     model,
@@ -635,7 +635,6 @@ def quantize_and_upload(
             group_size=128,
             packing_format="tile_packed_to_4d",
             int4_choose_qparams_algorithm="hqq",
-            version=2,
         ),
         "INT8-INT4": ModuleFqnToConfig(
             {
@@ -669,7 +668,7 @@ def quantize_and_upload(
         )
         tokenizer = AutoTokenizer.from_pretrained(model_id)
 
-        base_config = Int4WeightOnlyConfig(group_size=128, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=128)
         quant_config = AWQConfig(base_config, step="prepare")
         quantize_(
             model,
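For orientation, a condensed sketch of the AWQ flow these hunks touch, now relying on the version-2 default. This assumes `AWQConfig` is imported from `torchao.prototype.awq` (the import is not shown in these hunks), and the calibration pass and final convert step are assumptions based on the usual AWQ prepare/convert workflow rather than lines from this diff.

import torch
from torchao.prototype.awq import AWQConfig  # assumed import path for the prototype AWQ workflow
from torchao.quantization import Int4WeightOnlyConfig, quantize_

# Toy stand-in for the model being released; illustrative only.
model = torch.nn.Sequential(torch.nn.Linear(256, 256, dtype=torch.bfloat16, device="cuda"))

base_config = Int4WeightOnlyConfig(group_size=128)  # now picks up the version=2 default
quantize_(model, AWQConfig(base_config, step="prepare"))

# Run a few calibration batches so the AWQ observers see activations.
model(torch.randn(4, 256, dtype=torch.bfloat16, device="cuda"))

# Assumed follow-up step of the AWQ workflow; not part of this hunk.
quantize_(model, AWQConfig(base_config, step="convert"))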

test/dtypes/test_affine_quantized.py

Lines changed: 4 additions & 5 deletions
@@ -28,7 +28,6 @@
 from torchao.quantization import (
     FbgemmConfig,
     GemliteUIntXWeightOnlyConfig,
-    Int4WeightOnlyConfig,
     Int8DynamicActivationInt8WeightConfig,
     float8_weight_only,
     int4_dynamic_activation_int4_weight,
@@ -354,7 +353,7 @@ def test_slice_int4wo(self, device, dtype):
         # out_feature not divisible by 8
         # to test slice + padding for int4 weight only quantization
         dummy = nn.Linear(256, 321, dtype=dtype, device=device)
-        quantize_(dummy, Int4WeightOnlyConfig(version=1))
+        quantize_(dummy, int4_weight_only(version=1))
         # make sure these run without error
         _ = dummy.weight.narrow(0, 0, 64)
         _ = dummy.weight.narrow(1, 0, 128)
@@ -468,7 +467,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
         l.weight = torch.nn.Parameter(
             torch.zeros(1024, 1024, dtype=torch.bfloat16, device="cuda")
         )
-        quantize_(l, Int4WeightOnlyConfig(version=1))
+        quantize_(l, int4_weight_only(version=1))
         param = l.weight
         param_data = param.data
         param_data = param_data.narrow(0, 0, 512)
@@ -484,7 +483,7 @@ def test_slice_and_copy_int4wo(self, device, dtype):
 
         # dummy_l has random input (shouldn't be 0)
         dummy_l = torch.nn.Linear(1024, 1024).to("cuda").to(torch.bfloat16)
-        quantize_(dummy_l, Int4WeightOnlyConfig(version=1))
+        quantize_(dummy_l, int4_weight_only(version=1))
         quantized = dummy_l.weight
         quantized = quantized.narrow(0, 0, 512)
 
@@ -503,7 +502,7 @@ def test_mm_int4wo(self, device, dtype):
 
         l = torch.nn.Linear(512, 1024).to(device).to(dtype)
         l.weight = torch.nn.Parameter(weight)
-        quantize_(l, Int4WeightOnlyConfig(version=1))
+        quantize_(l, int4_weight_only(version=1))
         # weight shape: 1024 x 512
         weight = l.weight
 

test/integration/test_load_and_run_checkpoint.py

Lines changed: 93 additions & 18 deletions
@@ -24,9 +24,22 @@
 
 # please check model card for how to generate these models
 
-_DEPRECATED_SINGLE_LINEAR_MODEL_NAMES = [
+# high precision model, used for testing config deprecation warning
+_HIGH_PRECISION_MODEL = "facebook/opt-125m"
+
+_DEPRECATED_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev"
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v1-0.13.dev",
+        1,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
+    # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 _DEPRECATED_MODEL_INFO = [
@@ -36,15 +49,33 @@
         1,
         "Float8DynamicActivationFloat8WeightConfig",
     ),
+    # model card: https://huggingface.co/torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev
+    (
+        "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev",
+        1,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
-_SINGLE_LINEAR_MODEL_NAMES = [
+_SINGLE_LINEAR_MODEL_INFO = [
     # model card: https://huggingface.co/torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Float8DynamicActivationFloat8WeightConfig-v2-0.13.dev",
+        2,
+        "Float8DynamicActivationFloat8WeightConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
     # model card: https://huggingface.co/torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev
-    "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+    (
+        "torchao-testing/single-linear-Int4WeightOnlyConfig-preshuffled-v2-0.13.dev",
+        2,
+        "Int4WeightOnlyConfig",
+    ),
 ]
 
 
@@ -55,7 +86,9 @@
     "Skipping the test in fbcode for now, not sure how to download from transformers",
 )
 class TestLoadAndRunCheckpoint(TestCase):
-    def _test_single_linear_helper(self, model_name):
+    def _test_single_linear_helper(
+        self, model_name, version, config_name, is_deprecated
+    ):
         from huggingface_hub import hf_hub_download
 
         downloaded_model = hf_hub_download(model_name, filename="model.pt")
@@ -69,8 +102,20 @@ def _test_single_linear_helper(self, model_name):
         model = torch.nn.Sequential(
             torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
         )
-        with open(downloaded_model, "rb") as f:
+
+        with (
+            open(downloaded_model, "rb") as f,
+            warnings.catch_warnings(record=True) as caught_warnings,
+        ):
             model.load_state_dict(torch.load(f), assign=True)
+            if is_deprecated:
+                assert any(
+                    f"Models quantized with version {version} of {config_name} is deprecated"
+                    in str(w.message)
+                    for w in caught_warnings
+                ), (
+                    f"Didn't get expected warning message for deprecation for model: {model_name}"
+                )
 
         downloaded_example_inputs = hf_hub_download(
             model_name, filename="model_inputs.pt"
@@ -84,17 +129,23 @@ def _test_single_linear_helper(self, model_name):
         output = model(*example_inputs)
         self.assertTrue(torch.equal(output, ref_output))
 
-    @common_utils.parametrize("model_name", _DEPRECATED_SINGLE_LINEAR_MODEL_NAMES)
-    def test_deprecated_single_linear(self, model_name):
-        self._test_single_linear_helper(model_name)
+    @common_utils.parametrize("model_info", _DEPRECATED_SINGLE_LINEAR_MODEL_INFO)
+    def test_deprecated_single_linear(self, model_info):
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=True
+        )
 
-    @common_utils.parametrize("model_name", _SINGLE_LINEAR_MODEL_NAMES)
-    def test_single_linear(self, model_name):
+    @common_utils.parametrize("model_info", _SINGLE_LINEAR_MODEL_INFO)
+    def test_single_linear(self, model_info):
         """Test that we can load and run the quantized linear checkpoint with saved sample input
         and match the saved output, to make sure there is no BC breaking changes
         when we make changes to tensor subclass implementations
         """
-        self._test_single_linear_helper(model_name)
+        model_name, version, config_name = model_info
+        self._test_single_linear_helper(
+            model_name, version, config_name, is_deprecated=False
+        )
 
     @common_utils.parametrize("model_info", _DEPRECATED_MODEL_INFO)
     def test_deprecated_hf_models(self, model_info):
@@ -109,17 +160,23 @@ def test_deprecated_hf_models(self, model_info):
             torch_dtype="bfloat16",
             device_map="cuda:0",
         )
+        # version mismatch check in config.py
         assert any(
             "Stored version is not the same as current default version of the config"
             in str(w.message)
             for w in caught_warnings
-        ), "Didn't get expected warning message for version mismatch"
+        ), (
+            f"Didn't get expected warning message for version mismatch for config {config_name}, model {model_name}"
+        )
 
+        # checkpoint deprecation
         assert any(
-            f"Models quantized with version 1 of {config_name} is deprecated"
+            f"Models quantized with version {version} of {config_name} is deprecated"
             in str(w.message)
             for w in caught_warnings
-        ), "Didn't get expected warning message for deprecation"
+        ), (
+            f"Didn't get expected warning message for deprecation for model {model_name}"
+        )
         assert isinstance(quantized_model.config.quantization_config, TorchAoConfig)
         assert (
             quantized_model.config.quantization_config.quant_type.version == version
@@ -139,7 +196,8 @@ def test_deprecated_hf_models(self, model_info):
             return_tensors="pt",
         ).to("cuda")
         generated_ids = quantized_model.generate(
-            **inputs, max_new_tokens=128, temperature=0
+            **inputs,
+            max_new_tokens=128,
         )
 
         downloaded_output = hf_hub_download(model_name, filename="model_output.pt")
@@ -153,6 +211,23 @@ def test_deprecated_hf_models(self, model_info):
             generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
         )
 
+        # make sure we throw warning for config deprecation
+        with warnings.catch_warnings(record=True) as caught_warnings:
+            _ = AutoModelForCausalLM.from_pretrained(
+                _HIGH_PRECISION_MODEL,
+                torch_dtype="bfloat16",
+                device_map="cuda:0",
+                quantization_config=quantized_model.config.quantization_config,
+            )
+            # config version deprecation in quant_api.py
+            assert any(
+                f"Config Deprecation: version {version} of {config_name} is deprecated and will no longer be supported in a future release"
+                in str(w.message)
+                for w in caught_warnings
+            ), (
+                f"Didn't get expected warning message for version deprecation for config {config_name}, model {model_name}"
+            )
+
 
 common_utils.instantiate_parametrized_tests(TestLoadAndRunCheckpoint)
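For context, a minimal sketch of what these assertions exercise when one of the version-1 checkpoints listed above is loaded (warning strings and the model name are copied from the hunks; this assumes the transformers + torchao integration used by the test and a CUDA device):

import warnings
from transformers import AutoModelForCausalLM

with warnings.catch_warnings(record=True) as caught_warnings:
    warnings.simplefilter("always")
    quantized_model = AutoModelForCausalLM.from_pretrained(
        "torchao-testing/opt-125m-Int4WeightOnlyConfig-v1-0.14.dev",
        torch_dtype="bfloat16",
        device_map="cuda:0",
    )

# Expected, per the assertions above:
#   "Stored version is not the same as current default version of the config"
#   "Models quantized with version 1 of Int4WeightOnlyConfig is deprecated"
for w in caught_warnings:
    print(w.message)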
test/prototype/test_awq.py

Lines changed: 3 additions & 3 deletions
@@ -73,7 +73,7 @@ def test_awq_functionality(self):
         m = ToyLinearModel(l1, l2, l3).eval().to(original_dtype).to(device)
 
         # baseline quantization
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         m_baseline = copy.deepcopy(m)
         quantize_(m_baseline, base_config)
 
@@ -123,7 +123,7 @@ def test_awq_loading(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 
@@ -177,7 +177,7 @@ def test_awq_loading_vllm(self):
         calibration_data = dataset[:n_calibration_examples]
 
         # calibrate
-        base_config = Int4WeightOnlyConfig(group_size=group_size, version=2)
+        base_config = Int4WeightOnlyConfig(group_size=group_size)
         quant_config = AWQConfig(base_config, step=AWQStep.PREPARE)
         quantize_(m, quant_config)
 

test/quantization/quantize_/workflows/int4/test_int4_marlin_sparse_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -27,7 +27,6 @@
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="marlin_sparse",
-    version=2,
 )
 

test/quantization/quantize_/workflows/int4/test_int4_opaque_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         int4_packing_format="opaque",
-        version=2,
     )
 

test/quantization/quantize_/workflows/int4/test_int4_plain_int32_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -29,7 +29,6 @@ def get_config(group_size):
     return Int4WeightOnlyConfig(
         group_size=group_size,
         int4_packing_format="plain_int32",
-        version=2,
     )
 

test/quantization/quantize_/workflows/int4/test_int4_preshuffled_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -30,7 +30,6 @@
 BF16_ACT_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="preshuffled",
-    version=2,
 )
 
 # only 128 group_size is supported

test/quantization/quantize_/workflows/int4/test_int4_tensor.py

Lines changed: 0 additions & 1 deletion
@@ -35,7 +35,6 @@ def setUp(self):
         self.config = Int4WeightOnlyConfig(
             group_size=128,
             int4_packing_format="plain",
-            version=2,
         )
         self.GPU_DEVICES = ["cuda"] if torch.cuda.is_available() else []
 

test/quantization/quantize_/workflows/int4/test_int4_tile_packed_to_4d_tensor.py

Lines changed: 0 additions & 2 deletions
@@ -25,14 +25,12 @@
 INT4_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="tile_packed_to_4d",
-    version=2,
 )
 
 INT4_HQQ_CONFIG = Int4WeightOnlyConfig(
     group_size=128,
     int4_packing_format="tile_packed_to_4d",
     int4_choose_qparams_algorithm="hqq",
-    version=2,
 )
 
