 from torchao.quantization.weight_only import (
     WeightOnlyInt8QuantLinear
 )
-
+import os
 
 torch.manual_seed(0)
 
@@ -932,6 +932,63 @@ def test_weight_only_quant_use_mixed_mm(self):
         sqnr = compute_error(y_ref, y_wo)
         self.assertGreater(sqnr, 43.0)
 
+class TestSaveLoadMeta(unittest.TestCase):
+    @torch.no_grad()
+    def _test_handle_save_load_meta_impl(self, api):
+        m, k, n = 32, 64, 32
+        class test_model(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.lin1 = nn.Linear(k, n)
+                self.relu = nn.ReLU()
+                self.lin2 = nn.Linear(n, n)
+
+            def forward(self, x):
+                x = self.lin1(x)
+                x = self.relu(x)
+                x = self.lin2(x)
+                return x
+
+        x = torch.randn(m, k, dtype=torch.bfloat16, device="cuda")
+
+        # get float reference
+        model = test_model().to(torch.bfloat16).cuda().eval()
+        ref_f = model(x)
+
+        # save quantized state_dict
+        api(model)
+        torch.save(model.state_dict(), "test.pth")
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        ref_q = model_qc(x).detach()
+
+        assert SQNR(ref_f, ref_q) > 35
+
+        # load model structure
+        with torch.device('meta'):
+            model = test_model()
+        api(model)
+
+        # load quantized state_dict
+        state_dict = torch.load("test.pth", mmap=True)
+        os.remove("test.pth")
+        model.load_state_dict(state_dict, assign=True)
+        model = model.to(torch.bfloat16).cuda().eval()
+
+        # get quantized reference
+        model_qc = torch.compile(model, mode="max-autotune")
+        test = model_qc(x).detach()
+
+        assert SQNR(ref_f, test) > 35
+        self.assertTrue(torch.equal(ref_q, test))
+
+    @torch.no_grad()
+    def test_save_load_dqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_dqtensors)
+
+    @torch.no_grad()
+    def test_save_load_woqtensors(self):
+        self._test_handle_save_load_meta_impl(change_linear_weights_to_woqtensors)
 
 class TorchCompileUnitTest(unittest.TestCase):
     def test_fullgraph(self):
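For readers unfamiliar with the pattern the new test exercises, the following is a minimal sketch of the save/load-via-meta-device flow in isolation, assuming a PyTorch build recent enough to support `load_state_dict(..., assign=True)` and `torch.load(..., mmap=True)` (2.1+). `ToyModel` and `"toy.pth"` are illustrative names, and the quantization step (`api(model)` in the test) is intentionally omitted; this is not part of the commit.

```python
# Sketch: materialize a module built on the meta device by assigning
# tensors loaded from a checkpoint. Names here are illustrative only.
import os
import torch
import torch.nn as nn

class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.lin1 = nn.Linear(64, 32)
        self.relu = nn.ReLU()
        self.lin2 = nn.Linear(32, 32)

    def forward(self, x):
        return self.lin2(self.relu(self.lin1(x)))

# 1. Build a real model and save its state_dict.
model = ToyModel().eval()
torch.save(model.state_dict(), "toy.pth")

# 2. Re-create only the module structure on the meta device;
#    no parameter storage is allocated at this point.
with torch.device("meta"):
    meta_model = ToyModel()

# 3. mmap the checkpoint and hand the loaded tensors to the module with
#    assign=True, which swaps in the loaded tensors instead of copying
#    into (non-existent) meta storage.
state_dict = torch.load("toy.pth", mmap=True)
meta_model.load_state_dict(state_dict, assign=True)
os.remove("toy.pth")

# The module is now materialized and usable for inference.
with torch.no_grad():
    out = meta_model(torch.randn(8, 64))
print(out.shape)  # torch.Size([8, 32])
```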