Commit 1572b34

committed
Use PreTrainedModel.from_pretrained
1 parent 8634b2e commit 1572b34

File tree

1 file changed: +99 −81 lines


test/prototype/test_parq.py

Lines changed: 99 additions & 81 deletions
@@ -4,6 +4,7 @@
 # This source code is licensed under the BSD 3-Clause license found in the
 # LICENSE file in the root directory of this source tree.
 import copy
+import tempfile
 import unittest
 from typing import Optional
 
@@ -53,9 +54,83 @@
 
 _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
+
+class M(nn.Module):
+    _tied_weights_keys: list[str] = []
+
+    def __init__(
+        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
+    ):
+        nn.Module.__init__(self)
+        self.embed_tokens = nn.Embedding(k, m) if embedding else nn.Identity()
+        self.linear1 = nn.Linear(m, n, bias=bias)
+        self.linear2 = nn.Linear(n, k, bias=bias)
+        self.relu = nn.ReLU()
+        self.sigmoid = nn.Sigmoid()
+
+        if embedding and tied_weights:
+            assert self.embed_tokens.weight.shape == self.linear2.weight.shape
+            self.tie_weights()
+            self._tied_weights_keys.append("linear2.weight")
+
+    def tie_weights(self):
+        self.linear2.weight = self.embed_tokens.weight
+
+    def example_inputs(self, device=None):
+        if isinstance(self.embed_tokens, nn.Identity):
+            inputs = torch.randn(1, self.linear1.in_features, device=device)
+        else:
+            k = self.embed_tokens.num_embeddings
+            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
+        return inputs
+
+    def forward(self, x):
+        x = self.embed_tokens(x)
+        x = self.relu(self.linear1(x))
+        x = self.sigmoid(self.linear2(x))
+        return x
+
+
 if TRANSFORMERS_AVAIL:
-    from transformers import PretrainedConfig, TorchAoConfig
-    from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
+    from transformers import PretrainedConfig, PreTrainedModel, TorchAoConfig
+
+    class MConfig(PretrainedConfig):
+        def __init__(
+            self,
+            m=256,
+            n=128,
+            k=16,
+            bias=False,
+            embedding=True,
+            tied_weights=False,
+            **kwargs,
+        ):
+            super().__init__(**kwargs)
+            self.m = m
+            self.n = n
+            self.k = k
+            self.bias = bias
+            self.embedding = embedding
+            self.tied_weights = tied_weights
+
+    class PreTrainedM(M, PreTrainedModel):
+        base_model_prefix = "base"
+        config_class = MConfig
+
+        def __init__(self, config: MConfig):
+            PreTrainedModel.__init__(self, config)
+            M.__init__(
+                self,
+                m=config.m,
+                n=config.n,
+                k=config.k,
+                bias=config.bias,
+                embedding=config.embedding,
+                tied_weights=config.tied_weights,
+            )
+
+        def get_input_embeddings(self) -> nn.Module:
+            return self.embed_tokens
 
 
 def split_param_groups(model) -> tuple[list, list, list]:
@@ -199,55 +274,9 @@ def apply_activation_quantization(
         pass
 
 
-class M(nn.Module):
-    _tied_weights_keys: list[str] = []
-
-    def __init__(
-        self, m=256, n=128, k=16, bias=False, embedding=True, tied_weights=False
-    ):
-        super().__init__()
-        self.embedding = nn.Embedding(k, m) if embedding else nn.Identity()
-        self.linear1 = nn.Linear(m, n, bias=bias)
-        self.linear2 = nn.Linear(n, k, bias=bias)
-        self.relu = nn.ReLU()
-        self.sigmoid = nn.Sigmoid()
-
-        if embedding and tied_weights:
-            assert self.embedding.weight.shape == self.linear2.weight.shape
-            self.tie_weights()
-            self._tied_weights_keys.append("linear2.weight")
-
-    def tie_weights(self):
-        self.linear2.weight = self.embedding.weight
-
-    def reset_parameters(self):
-        for module in (self.linear1, self.linear2):
-            nn.init.xavier_uniform_(module.weight)
-            if module.bias is not None:
-                nn.init.zeros_(module.bias)
-
-    def example_inputs(self, device=None):
-        if isinstance(self.embedding, nn.Identity):
-            inputs = torch.randn(1, self.linear1.in_features, device=device)
-        else:
-            k = self.embedding.num_embeddings
-            inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
-        return inputs
-
-    def get_input_embeddings(self) -> nn.Module:
-        return self.embedding
-
-    def forward(self, x):
-        x = self.embedding(x)
-        x = self.relu(self.linear1(x))
-        x = self.sigmoid(self.linear2(x))
-        return x
-
-
 class TestPARQuantization(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
-        self.model = M(bias=True).to(_DEVICE)
 
     @common_utils.parametrize("b", [0, 1, 2, 4])
     @common_utils.parametrize("unif_quant", [True, False])
@@ -256,13 +285,13 @@ def setUp(self):
     def test_parq_train_loop(
         self, b: int = 2, unif_quant=True, hard_prox=True, per_group_quantizer=False
     ):
-        self.model.reset_parameters()
+        model = M(bias=True).to(_DEVICE)
         if unif_quant:
             quantizer = TernaryUnifQuantizer() if b == 0 else UnifQuantizer()
         else:
             quantizer = LSBQuantizer()
         param_groups = build_param_groups(
-            self.model, b, quantizer=quantizer if per_group_quantizer else None
+            model, b, quantizer=quantizer if per_group_quantizer else None
         )
         base_optimizer = torch.optim.AdamW(param_groups)
 
@@ -271,12 +300,12 @@ def test_parq_train_loop(
         )
         optimizer = QuantOptimizer(base_optimizer, quantizer, prox_map)
         for _ in range(3):
-            x = self.model.example_inputs(device=_DEVICE)
-            out = self.model(x)
+            x = model.example_inputs(device=_DEVICE)
+            out = model(x)
             out.sum().backward()
             optimizer.step()
 
-        for child in self.model.children():
+        for child in model.children():
             if isinstance(child, nn.Linear):
                 self.assertEqual(
                     child.weight.unique().numel(), quantizer.get_quant_size(b)
@@ -295,7 +324,6 @@ def setUp(self):
     @common_utils.parametrize("group_size", [32, 256])
     def test_int4_weight_only(self, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE, dtype=torch.bfloat16)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -313,7 +341,6 @@ def test_int4_weight_only(self, group_size: int = 32):
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         quantize_(
@@ -333,7 +360,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     )
     def test_int4_weight_only_e2e(self, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(torch.bfloat16).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = Int4WeightOnlyConfig(group_size=group_size)
@@ -353,7 +379,6 @@ def test_int4_weight_only_e2e(self, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3, 4, 8])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()
 
         m_ref = copy.deepcopy(model).eval().to(_DEVICE)
         config = IntxWeightOnlyConfig(
@@ -380,7 +405,6 @@ def setUp(self):
    @common_utils.parametrize("group_size", [32, 256])
    def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32):
        model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
        quantizer_ref = UnifQuantizer()
        quantizer = StretchedUnifTorchaoQuantizer(b)
@@ -403,7 +427,6 @@ def test_intx_weight_only_parq_equivalent(self, b: int = 2, group_size: int = 32
     @common_utils.parametrize("group_size", [32, 512])
     def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512).to(_DEVICE)
-        model.reset_parameters()
 
         quantizer = StretchedUnifTorchaoQuantizer(b)
 
@@ -425,7 +448,6 @@ def test_intx_weight_only(self, b: int = 2, group_size: int = 32):
     @common_utils.parametrize("b", [2, 3])
     def test_intx_weight_only_e2e(self, b: int = 2, group_size: int = 32):
         model = M(m=512, n=512, embedding=False).to(_DEVICE)
-        model.reset_parameters()
 
         quantizer = StretchedUnifTorchaoQuantizer(b)
 
@@ -470,14 +492,16 @@ def test_intx_weight_only_tied_embed_linear(
         optimizer.torchao_convert(model)
         check_torchao_tensor_subclass(self, model)
         self.assertTrue(
-            torch.equal(model.embedding.weight.qdata, model.linear2.weight.qdata)
+            torch.equal(model.embed_tokens.weight.qdata, model.linear2.weight.qdata)
         )
 
 
 class TestInt8DynamicActivationTorchaoQuantizer(common_utils.TestCase):
     def setUp(self):
         torch.manual_seed(123)
 
+    @unittest.skipIf(_DEVICE == "cpu", "Need GPU available")
+    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
     @common_utils.parametrize("b", [2, 3, 4, 8])
     @common_utils.parametrize(
         "model_dtype", [torch.float16, torch.float32, torch.bfloat16]
@@ -489,7 +513,8 @@ def test_int8_dynamic_activation_intx_e2e(
         model_dtype: torch.dtype = torch.float32,
         group_size: int = 32,
     ):
-        model = M(embedding=False, bias=True).to(_DEVICE, dtype=model_dtype)
+        config = MConfig(embedding=False, bias=True)
+        model = PreTrainedM(config).to(_DEVICE, dtype=model_dtype)
         x = model.example_inputs(device=_DEVICE).to(model_dtype)
 
         # reference model using native quantization
@@ -520,7 +545,6 @@ def test_int8_dynamic_activation_intx_e2e(
 
         attach_hf_config = False
         if TRANSFORMERS_AVAIL:
-            model.config = PretrainedConfig()  # pretend this is a HF model
             attach_hf_config = _is_hf_model(model)
             self.assertTrue(attach_hf_config)
 
@@ -543,10 +567,11 @@ def test_int8_dynamic_activation_intx_e2e(
 
 
 class TestTorchAoConfigIntegration(common_utils.TestCase):
+    @unittest.skipIf(torch.backends.mps.is_available(), "MPS not supported")
     @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
     def test_tied_weights_quantization(self, b: int = 4):
-        model = M(m=128, n=128, tied_weights=True).to(_DEVICE)
-        model.config = PretrainedConfig()  # pretend this is a HF model
+        config = MConfig(m=128, n=128, tied_weights=True)
+        model = PreTrainedM(config).to(_DEVICE)
 
         quantizer = StretchedUnifTorchaoQuantizer(b)
         linear_config = StretchedIntxWeightConfig(
@@ -567,27 +592,20 @@ def test_tied_weights_quantization(self, b: int = 4):
         self.assertTrue(isinstance(quantization_config, TorchAoConfig))
         self.assertTrue(quantization_config.modules_to_not_convert == ["linear2"])
 
-        # Simulate transformers.PreTrainedModel.from_pretrained
-        hf_quantizer = TorchAoHfQuantizer(
-            quantization_config,
-            pre_quantized=False,
-            modules_to_not_convert=quantization_config.modules_to_not_convert,
-        )
-        state_dict = model.state_dict()
-        unexpected_keys = []
-        for n, p in state_dict.items():
-            if hf_quantizer.check_quantized_param(model, p, n, state_dict):
-                hf_quantizer.create_quantized_param(
-                    model, p, n, _DEVICE, state_dict, unexpected_keys
-                )
-        model.tie_weights()
+        # Let HF apply quantize_ given quantization_config
+        del model.config.quantization_config
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            model.save_pretrained(tmp_dir, safe_serialization=False)
+            model = PreTrainedM.from_pretrained(
+                tmp_dir, quantization_config=quantization_config
+            )
 
         check_torchao_tensor_subclass(self, model.linear1)
         check_torchao_tensor_subclass(self, model.linear2, weight_only=True)
-        check_torchao_tensor_subclass(self, model.embedding, weight_only=True)
+        check_torchao_tensor_subclass(self, model.embed_tokens, weight_only=True)
 
         self.assertTrue(
-            model.linear2.weight.data_ptr() == model.embedding.weight.data_ptr()
+            model.linear2.weight.data_ptr() == model.embed_tokens.weight.data_ptr()
         )
 
 
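Note: the last hunk replaces the hand-driven TorchAoHfQuantizer simulation with a real save_pretrained / from_pretrained round trip, so transformers itself runs torchao's quantize_ while loading. Below is a minimal sketch of that pattern, reusing the MConfig / PreTrainedM test helpers added above. The TorchAoConfig("int4_weight_only", ...) shown here is an illustrative stand-in, not what the test uses (the test wraps PARQ's StretchedIntxWeightConfig, which is built outside this diff), and int4 weight-only loading generally expects bfloat16 weights and a CUDA device.

import tempfile

import torch
from transformers import TorchAoConfig

# Illustrative quantization config only; see caveats in the note above.
quantization_config = TorchAoConfig("int4_weight_only", group_size=32)

config = MConfig(embedding=False, bias=True)
model = PreTrainedM(config).to(torch.bfloat16)

with tempfile.TemporaryDirectory() as tmp_dir:
    # Save the unquantized checkpoint, then let PreTrainedModel.from_pretrained
    # quantize the weights on load, driven by quantization_config.
    model.save_pretrained(tmp_dir, safe_serialization=False)
    model = PreTrainedM.from_pretrained(tmp_dir, quantization_config=quantization_config)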
0 commit comments
