Commit fb7c837

Avoid normalization layers in HF's quantization_config (#3030)
* Avoid normalization layers in HF's quantization_config
* Add TestTorchAoConfigIntegration
* Use PreTrainedModel.from_pretrained
1 parent bc72e1c commit fb7c837
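
At a glance, _attach_hf_quantization_config now decides what to exclude from quantization by module name rather than by weight pointer: modules listed in the model's _tied_weights_keys and any module whose name contains "norm" are placed in modules_to_not_convert of the attached TorchAoConfig. A minimal sketch of that rule follows; collect_modules_to_not_convert is a made-up helper name for illustration, the real check is inlined in config_torchao.py further down.

import torch.nn as nn


def collect_modules_to_not_convert(model: nn.Module) -> list[str]:
    # Mirrors the skip rule added in this commit (illustrative only).
    tied_weights_keys = set(getattr(model, "_tied_weights_keys", []))
    skipped = []
    for name, module in model.named_modules():
        if not hasattr(module, "weight"):
            continue
        # Tied duplicates and normalization layers are reported to HF so that
        # quantize_ leaves them alone when the model is reloaded.
        if f"{name}.weight" in tied_weights_keys or "norm" in name:
            skipped.append(name)
    return skipped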

2 files changed (+141, -63 lines)

test/prototype/test_parq.py

Lines changed: 138 additions & 59 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import tempfile
 import unittest
 from typing import Optional
 
@@ -27,9 +28,13 @@
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
-from torchao.prototype.parq.quant.config_torchao import TRANSFORMERS_AVAIL, _is_hf_model
+from torchao.prototype.parq.quant.config_torchao import (
+    TRANSFORMERS_AVAIL,
+    _attach_hf_quantization_config,
+    _is_hf_model,
+)
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
-from torchao.quantization.granularity import PerGroup
+from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quant_api import (
     Int4WeightOnlyConfig,
@@ -50,6 +55,84 @@
 _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
+class M(nn.Module):
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
+        nn.Module.__init__(self)
+        self.embed_tokens = nn.Embedding(k, m) if embedding else nn.Identity()
+        self.linear1 = nn.Linear(m, n, bias=bias)
+        self.linear2 = nn.Linear(n, k, bias=bias)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+        if embedding and tied_weights:
+            assert self.embed_tokens.weight.shape == self.linear2.weight.shape
+            self.tie_weights()
+            self._tied_weights_keys.append("linear2.weight")
+
+    def tie_weights(self):
+        self.linear2.weight = self.embed_tokens.weight
+
+    def example_inputs(self, device=None):
+        if isinstance(self.embed_tokens, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embed_tokens.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs
+
+    def forward(self, x):
+        x = self.embed_tokens(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
+        return x
+
+
+if TRANSFORMERS_AVAIL:
+    from transformers import PretrainedConfig, PreTrainedModel, TorchAoConfig
+
+    class MConfig(PretrainedConfig):
+        def __init__(
+            self,
+            m=256,
+            n=128,
+            k=16,
+            bias=False,
+            embedding=True,
+            tied_weights=False,
+            **kwargs,
+        ):
+            super().__init__(**kwargs)
+            self.m = m
+            self.n = n
+            self.k = k
+            self.bias = bias
+            self.embedding = embedding
+            self.tied_weights = tied_weights
+
+    class PreTrainedM(M, PreTrainedModel):
+        base_model_prefix = "base"
+        config_class = MConfig
+
+        def __init__(self, config: MConfig):
+            PreTrainedModel.__init__(self, config)
+            M.__init__(
+                self,
+                m=config.m,
+                n=config.n,
+                k=config.k,
+                bias=config.bias,
+                embedding=config.embedding,
+                tied_weights=config.tied_weights,
+            )
+
+        def get_input_embeddings(self) -> nn.Module:
+            return self.embed_tokens
+
+
 def split_param_groups(model) -> tuple[list, list, list]:
     params_quant, params_embed, params_no_quant = [], [], []
 
@@ -191,49 +274,9 @@ def apply_activation_quantization(
         pass
 
 
-class M(nn.Module):
-    _tied_weights_keys: list[str] = []
-
-    def __init__(
-        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
-    ):
-        super().__init__()
-        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
-        self.linear1 = nn.Linear(m, n, bias=bias)
-        self.linear2 = nn.Linear(n, k, bias=bias)
-        self.relu = nn.ReLU()
-        self.sigmoid = nn.Sigmoid()
-
-        if embedding and tied_weights:
-            assert self.embedding.weight.shape == self.linear2.weight.shape
-            self.linear2.weight = self.embedding.weight
-            self._tied_weights_keys.append("linear2.weight")
-
-    def reset_parameters(self):
-        for module in (self.linear1, self.linear2):
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
-
-    def example_inputs(self, device=None):
-        if isinstance(self.embedding, nn.Identity):
-            inputs = torch.randn(1, self.linear1.in_features, device=device)
-        else:
-            k = self.embedding.num_embeddings
-            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
-        return inputs
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.relu(self.linear1(x))
-        x = self.sigmoid(self.linear2(x))
-        return x
-
-
 class TestPARQuantization(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
-        self.model = M(bias=True).to(_DEVICE)
 
     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
@@ -242,13 +285,13 @@ def setUp(self):
     def test_parq_train_loop(
        self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False
     ):
-        self.model.reset_parameters()
+        model = M(bias=True).to(_DEVICE)
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
         param_groups = build_param_groups(
-            self.model, b, quantizer=quantizer if per_group_quantizer else None
+            model, b, quantizer=quantizer if per_group_quantizer else None
         )
         base_optimizer = torch.optim.AdamW(param_groups)
 
@@ -257,12 +300,12 @@ def test_parq_train_loop(
         )
         optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
         for _ in range(3):
-            x = self.model.example_inputs(device=_DEVICE)
-            out = self.model(x)
+            x = model.example_inputs(device=_DEVICE)
+            out = model(x)
             out.sum().backward()
             optimizer.step()
 
-        for child in self.model.children():
+        for child in model.children():
             if isinstance(child, nn.Linear):
                 self.assertEqual(
                     child.weight.unique().numel(), quantizer.get_quant_size(b)
@@ -281,7 +324,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -299,7 +341,6 @@ def test_int4_weight_only(self, group_size: int = 32):
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         quantize_(
@@ -319,7 +360,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     )
     def test_int4_weight_only_e2e(self, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(torch.bfloat16).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -339,7 +379,6 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3, 4, 8])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = IntxWeightOnlyConfig(
@@ -366,7 +405,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
         quantizer_ref = UnifQuantizer()
         quantizer = StretchedUnifTorchaoQuantizer(b)
@@ -389,7 +427,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
         quantizer = StretchedUnifTorchaoQuantizer(b)
 
@@ -411,7 +448,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()
 
         quantizer = StretchedUnifTorchaoQuantizer(b)
 
@@ -456,14 +492,16 @@ def test_intx_weight_only_tied_embed_linear(
         optimizer.torchao_convert(model)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
-            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+            torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
        )
 
 
 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
 
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
     @common_utils.parametrize("b", [2, 3, 4, 8])
     @common_utils.parametrize(
         "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
@@ -475,7 +513,8 @@ def test_int8_dynamic_activation_intx_e2e(
         model_dtype: torch.dtype = torch.float32,
         group_size: int = 32,
     ):
-        model = M(embedding=False, bias=True).to(_DEVICE, dtype=model_dtype)
+        config = MConfig(embedding=False, bias=True)
+        model = PreTrainedM(config).to(_DEVICE, dtype=model_dtype)
         x = model.example_inputs(device=_DEVICE).to(model_dtype)
 
         # reference model using native quantization
@@ -506,9 +545,6 @@ def test_int8_dynamic_activation_intx_e2e(
 
         attach_hf_config = False
         if TRANSFORMERS_AVAIL:
-            from transformers import PretrainedConfig
-
-            model.config = PretrainedConfig()  # pretend this is a HF model
             attach_hf_config = _is_hf_model(model)
         self.assertTrue(attach_hf_config)
 
@@ -530,6 +566,49 @@ def test_int8_dynamic_activation_intx_e2e(
         self.assertTrue(isinstance(torchao_config, config.__class__))
 
 
+class TestTorchAoConfigIntegration(common_utils.TestCase):
+    @unittest.skipIf(torch.backends.mps.is_available(), "MPS not supported")
+    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
+    def test_tied_weights_quantization(self, b: int = 4):
+        config = MConfig(m=128, n=128, tied_weights=True)
+        model = PreTrainedM(config).to(_DEVICE)
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+        linear_config = StretchedIntxWeightConfig(
+            b=b,
+            quant_min=quantizer.quant_min,
+            quant_max=quantizer.quant_max,
+            granularity=PerAxis(0),
+        )
+        embed_config = IntxWeightOnlyConfig(
+            weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(32)
+        )
+        module_to_config = {"_default": linear_config}
+        configs = [embed_config]
+        filter_fns = [lambda m: isinstance(m, nn.Embedding)]
+        _attach_hf_quantization_config(model, filter_fns, configs, module_to_config)
+
+        quantization_config = getattr(model.config, "quantization_config", None)
+        self.assertTrue(isinstance(quantization_config, TorchAoConfig))
+        self.assertTrue(quantization_config.modules_to_not_convert == ["linear2"])
+
+        # Let HF apply quantize_ given quantization_config
+        del model.config.quantization_config
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=False)
+            model = PreTrainedM.from_pretrained(
+                tmp_dir, quantization_config=quantization_config
+            )
+
+        check_torchao_tensor_subclass(self, model.linear1)
+        check_torchao_tensor_subclass(self, model.linear2, weight_only=True)
+        check_torchao_tensor_subclass(self, model.embed_tokens, weight_only=True)
+
+        self.assertTrue(
+            model.linear2.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
+        )
+
+
 common_utils.instantiate_parametrized_tests(TestPARQuantization)
 common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer)
 common_utils.instantiate_parametrized_tests(TestInt8DynamicActivationTorchaoQuantizer)

torchao/prototype/parq/quant/config_torchao.py

Lines changed: 3 additions & 4 deletions
@@ -187,17 +187,16 @@ def _attach_hf_quantization_config(
     if module_to_config is None:
         module_to_config = {}
 
-    seen_data_ptrs = set()
+    tied_weights_keys = set(getattr(model, "_tied_weights_keys", []))
     modules_to_not_convert = []
     for name, module in model.named_modules():
         if not hasattr(module, "weight"):
             continue
 
-        data_ptr = module.weight.data_ptr()
-        if data_ptr in seen_data_ptrs:  # do not re-quantize tied weight
+        # Do not quantize pointers to tied weights or normalization layers
+        if f"{name}.weight" in tied_weights_keys or "norm" in name:
             modules_to_not_convert.append(name)
             continue
-        seen_data_ptrs.add(data_ptr)
 
         for i, filter_fn in enumerate(filter_fns):
             if filter_fn(module):
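
The name-based check above replaces the old data_ptr() deduplication and additionally keeps normalization layers (any module whose name contains "norm") out of the attached quantization config. The flow the new test exercises, as a simplified sketch (not code from this commit; it assumes model is a transformers PreTrainedModel whose config already carries the TorchAoConfig produced by _attach_hf_quantization_config):

import tempfile

# Save unquantized weights, then let HF run quantize_ on reload, honoring
# modules_to_not_convert (tied weights and normalization layers).
quantization_config = model.config.quantization_config
del model.config.quantization_config
with tempfile.TemporaryDirectory() as tmp_dir:
    model.save_pretrained(tmp_dir, safe_serialization=False)
    model = type(model).from_pretrained(
        tmp_dir, quantization_config=quantization_config
    )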

0 commit comments
