
Commit 86d4311

Adding tests for save/load support
Summary: We can now save a model quantized with a tensor subclass: quantize the model, save its state dict, then later instantiate the model as meta tensors (i.e. load only tensor metadata, not actual parameters), apply the quantization API, and load the quantized state dict into it. We change the dtype of the subclass to match the dtype of its dequantized form, both to align with the subclass design guidelines and to make this flow work.

Test Plan: python test/test.py

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: 84a69d8
Pull Request resolved: #12
1 parent cbf2c9e commit 86d4311
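
The workflow described in the summary, written out as a standalone sketch (not part of this commit): build_model and apply_quant_api below are placeholders, with apply_quant_api standing in for a torchao API such as one of the change_linear_weights_to_* helpers that swaps Linear weights for quantized tensor subclasses in place.

import torch
import torch.nn as nn

# Hypothetical sketch of the save/load flow exercised by the new tests.
# apply_quant_api is a placeholder for a torchao quantization API that
# replaces nn.Linear weights with quantized tensor subclasses in place.

def build_model():
    return nn.Sequential(nn.Linear(64, 32), nn.ReLU(), nn.Linear(32, 32))

# 1. Quantize a materialized model and save its quantized state dict.
model = build_model().to(torch.bfloat16).cuda().eval()
apply_quant_api(model)
torch.save(model.state_dict(), "quantized.pth")

# 2. Later: rebuild the model on the meta device (tensor metadata only, no
#    parameter storage) and apply the same quantization API so the module
#    structure matches the saved state dict.
with torch.device("meta"):
    model = build_model()
apply_quant_api(model)

# 3. Load the saved state dict; assign=True swaps the loaded tensors in
#    place of the meta parameters instead of copying into them.
state_dict = torch.load("quantized.pth", mmap=True)
model.load_state_dict(state_dict, assign=True)
model = model.cuda().eval()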

File tree

2 files changed: +59 -2 lines changed


test/test.py

Lines changed: 58 additions & 1 deletion
@@ -56,7 +56,7 @@
 from torchao.quantization.weight_only import (
     WeightOnlyInt8QuantLinear
 )
-
+import os
 
 torch.manual_seed(0)
 
@@ -932,6 +932,63 @@ def test_weight_only_quant_use_mixed_mm(self):
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 43.0)
 
+class TestSaveLoadMeta(unittest.TestCase):
+    @torch.no_grad()
+    def _test_handle_save_load_meta_impl(self, api):
+        m, k, n = 32, 64, 32
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+        # get float reference
+        model = test_model().to(torch.bfloat16).cuda().eval()
+        ref_f = model(x)
+
+        # save quantized state_dict
+        api(model)
+        torch.save(model.state_dict(), "test.pth")
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        ref_q = model_qc(x).detach()
+
+        assert SQNR(ref_f, ref_q) > 35
+
+        # load model structure
+        with torch.device('meta'):
+            model = test_model()
+        api(model)
+
+        # load quantized state_dict
+        state_dict = torch.load("test.pth", mmap=True)
+        os.remove("test.pth")
+        model.load_state_dict(state_dict, assign=True)
+        model = model.to(torch.bfloat16).cuda().eval()
+
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        test = model_qc(x).detach()
+
+        assert SQNR(ref_f, test) > 35
+        self.assertTrue(torch.equal(ref_q, test))
+
+    @torch.no_grad()
+    def test_save_load_dqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)
+
+    @torch.no_grad()
+    def test_save_load_woqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)
 
 class TorchCompileUnitTest(unittest.TestCase):
     def test_fullgraph(self):

torchao/quantization/subclass.py

Lines changed: 1 addition & 1 deletion
@@ -32,7 +32,7 @@ def __new__(cls, int_data, q_scales, transposed=False, **kwargs):
         # transposed/detached, instead we can just pass the int_data to the
         # new instance and alter the transposed flag where needed.
         kwargs["device"] = int_data.device
-        kwargs["dtype"] = kwargs.get("dtype", torch.int8)
+        kwargs["dtype"] = kwargs.get("dtype", q_scales.dtype)
         size = int_data.shape[::-1] if transposed else int_data.shape
         kwargs["layout"] = (
             kwargs.get("layout") if kwargs.get("layout", False) else int_data.layout