Commit 8634b2e

Add TestTorchAoConfigIntegration
1 parent 0fc89a4 commit 8634b2e

2 files changed: +68, -9 lines

test/prototype/test_parq.py

Lines changed: 66 additions & 5 deletions
@@ -27,9 +27,13 @@
     UnifQuantizer,
     UnifTorchaoQuantizer,
 )
-from torchao.prototype.parq.quant.config_torchao import TRANSFORMERS_AVAIL, _is_hf_model
+from torchao.prototype.parq.quant.config_torchao import (
+    TRANSFORMERS_AVAIL,
+    _attach_hf_quantization_config,
+    _is_hf_model,
+)
 from torchao.prototype.parq.quant.uniform_torchao import _BIT_WIDTH_TO_DTYPE
-from torchao.quantization.granularity import PerGroup
+from torchao.quantization.granularity import PerAxis, PerGroup
 from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
 from torchao.quantization.quant_api import (
     Int4WeightOnlyConfig,
@@ -49,6 +53,10 @@

 _DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

+if TRANSFORMERS_AVAIL:
+    from transformers import PretrainedConfig, TorchAoConfig
+    from transformers.quantizers.quantizer_torchao import TorchAoHfQuantizer
+

 def split_param_groups(model) -> tuple[list, list, list]:
     params_quant, params_embed, params_no_quant = [], [], []
@@ -206,9 +214,12 @@ def __init__(

         if embedding and tied_weights:
             assert self.embedding.weight.shape == self.linear2.weight.shape
-            self.linear2.weight = self.embedding.weight
+            self.tie_weights()
             self._tied_weights_keys.append("linear2.weight")

+    def tie_weights(self):
+        self.linear2.weight = self.embedding.weight
+
     def reset_parameters(self):
         for module in (self.linear1, self.linear2):
             nn.init.xavier_uniform_(module.weight)
@@ -223,6 +234,9 @@ def example_inputs(self, device=None):
         inputs = torch.randint(1, k, (1, self.linear1.in_features), device=device)
         return inputs

+    def get_input_embeddings(self) -> nn.Module:
+        return self.embedding
+
     def forward(self, x):
         x = self.embedding(x)
         x = self.relu(self.linear1(x))
@@ -506,8 +520,6 @@ def test_int8_dynamic_activation_intx_e2e(

         attach_hf_config = False
         if TRANSFORMERS_AVAIL:
-            from transformers import PretrainedConfig
-
             model.config = PretrainedConfig()  # pretend this is a HF model
             attach_hf_config = _is_hf_model(model)
             self.assertTrue(attach_hf_config)
@@ -530,6 +542,55 @@ def test_int8_dynamic_activation_intx_e2e(
             self.assertTrue(isinstance(torchao_config, config.__class__))


+class TestTorchAoConfigIntegration(common_utils.TestCase):
+    @unittest.skipIf(not TRANSFORMERS_AVAIL, "Need transformers")
+    def test_tied_weights_quantization(self, b: int = 4):
+        model = M(m=128, n=128, tied_weights=True).to(_DEVICE)
+        model.config = PretrainedConfig()  # pretend this is a HF model
+
+        quantizer = StretchedUnifTorchaoQuantizer(b)
+        linear_config = StretchedIntxWeightConfig(
+            b=b,
+            quant_min=quantizer.quant_min,
+            quant_max=quantizer.quant_max,
+            granularity=PerAxis(0),
+        )
+        embed_config = IntxWeightOnlyConfig(
+            weight_dtype=_BIT_WIDTH_TO_DTYPE[b], granularity=PerGroup(32)
+        )
+        module_to_config = {"_default": linear_config}
+        configs = [embed_config]
+        filter_fns = [lambda m: isinstance(m, nn.Embedding)]
+        _attach_hf_quantization_config(model, filter_fns, configs, module_to_config)
+
+        quantization_config = getattr(model.config, "quantization_config", None)
+        self.assertTrue(isinstance(quantization_config, TorchAoConfig))
+        self.assertTrue(quantization_config.modules_to_not_convert == ["linear2"])
+
+        # Simulate transformers.PreTrainedModel.from_pretrained
+        hf_quantizer = TorchAoHfQuantizer(
+            quantization_config,
+            pre_quantized=False,
+            modules_to_not_convert=quantization_config.modules_to_not_convert,
+        )
+        state_dict = model.state_dict()
+        unexpected_keys = []
+        for n, p in state_dict.items():
+            if hf_quantizer.check_quantized_param(model, p, n, state_dict):
+                hf_quantizer.create_quantized_param(
+                    model, p, n, _DEVICE, state_dict, unexpected_keys
+                )
+        model.tie_weights()
+
+        check_torchao_tensor_subclass(self, model.linear1)
+        check_torchao_tensor_subclass(self, model.linear2, weight_only=True)
+        check_torchao_tensor_subclass(self, model.embedding, weight_only=True)
+
+        self.assertTrue(
+            model.linear2.weight.data_ptr() == model.embedding.weight.data_ptr()
+        )
+
+
 common_utils.instantiate_parametrized_tests(TestPARQuantization)
 common_utils.instantiate_parametrized_tests(TestUnifTorchaoQuantizer)
 common_utils.instantiate_parametrized_tests(TestInt8DynamicActivationTorchaoQuantizer)
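
For context, the test helper `M` now follows the conventions `transformers` uses for tied embeddings: a `_tied_weights_keys` attribute, a `tie_weights()` method, and `get_input_embeddings()`. That is what lets the new test re-tie the shared weight after `TorchAoHfQuantizer` swaps parameters during loading, mirroring what `from_pretrained` does. A minimal, self-contained sketch of the same pattern on a toy module (the `TiedToy` name and sizes are illustrative, not part of the test file):

```python
import torch.nn as nn


class TiedToy(nn.Module):
    """Toy module with an output projection tied to the embedding table."""

    def __init__(self, vocab: int = 32, dim: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(vocab, dim)          # (vocab, dim)
        self.linear = nn.Linear(dim, vocab, bias=False)    # weight also (vocab, dim)
        self._tied_weights_keys = ["linear.weight"]        # HF-style tie declaration
        self.tie_weights()

    def tie_weights(self):
        # Re-point the output projection at the embedding table; called again
        # after loading/quantization may have replaced parameters.
        self.linear.weight = self.embedding.weight

    def get_input_embeddings(self) -> nn.Module:
        return self.embedding

    def forward(self, x):
        return self.linear(self.embedding(x))


model = TiedToy()
# Same storage, as the test's data_ptr check asserts for `M`.
assert model.linear.weight.data_ptr() == model.embedding.weight.data_ptr()
```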

torchao/prototype/parq/quant/config_torchao.py

Lines changed: 2 additions & 4 deletions
@@ -187,18 +187,16 @@ def _attach_hf_quantization_config(
     if module_to_config is None:
         module_to_config = {}

-    seen_data_ptrs = set()
+    tied_weights_keys = set(getattr(model, "_tied_weights_keys", []))
     modules_to_not_convert = []
     for name, module in model.named_modules():
         if not hasattr(module, "weight"):
             continue

         # Do not quantize pointers to tied weights or normalization layers
-        data_ptr = module.weight.data_ptr()
-        if data_ptr in seen_data_ptrs or name.endswith("norm"):
+        if f"{name}.weight" in tied_weights_keys or "norm" in name:
             modules_to_not_convert.append(name)
             continue
-        seen_data_ptrs.add(data_ptr)

         for i, filter_fn in enumerate(filter_fns):
             if filter_fn(module):
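
The skip detection now keys off the model's `_tied_weights_keys` attribute (the convention `transformers` models expose) instead of comparing weight data pointers, and the normalization check is loosened from `name.endswith("norm")` to `"norm" in name`. A standalone sketch of the new skip logic on a toy model (illustrative names only, not torchao's code):

```python
import torch.nn as nn


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.embedding = nn.Embedding(16, 8)
        self.norm = nn.LayerNorm(8)
        self.lm_head = nn.Linear(8, 16, bias=False)
        self.lm_head.weight = self.embedding.weight  # tie output head to embedding
        self._tied_weights_keys = ["lm_head.weight"]


model = Toy()
tied_weights_keys = set(getattr(model, "_tied_weights_keys", []))
modules_to_not_convert = []
for name, module in model.named_modules():
    if not hasattr(module, "weight"):
        continue
    # Tied weights and normalization layers are excluded from quantization.
    if f"{name}.weight" in tied_weights_keys or "norm" in name:
        modules_to_not_convert.append(name)

print(modules_to_not_convert)  # ['norm', 'lm_head']
```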

0 commit comments
