Skip to content

Commit 86d7fcf

Browse files
committed
Update on "add int4tensor support for safetensors"
**Summary:** Add `Int4Tensor` support for safetensors (`Int4WeightOnlyConfig`).

**Test plan:** Modified the unit test to include `Int4WeightOnlyConfig`; run `python test/prototype/safetensors/test_safetensors_support.py`.

[ghstack-poisoned]
1 parent 6e31683 commit 86d7fcf

File tree

1 file changed

+8
-10
lines changed

1 file changed

+8
-10
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -42,22 +42,20 @@ class TestSafeTensors(TestCase):
4242
@parametrize(
4343
"config, act_pre_scale",
4444
[
45-
(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), None),
46-
(Int4WeightOnlyConfig(), None),
47-
(
48-
Int4WeightOnlyConfig(),
49-
torch.ones((1), dtype=torch.bfloat16),
50-
),
45+
(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False),
46+
(Int4WeightOnlyConfig(), False),
47+
(Int4WeightOnlyConfig(), True),
5148
],
5249
)
53-
def test_safetensors(self, config, act_pre_scale=None):
50+
def test_safetensors(self, config, act_pre_scale=False):
5451
model = torch.nn.Sequential(
5552
torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
5653
)
5754
quantize_(model, config)
58-
if act_pre_scale is not None:
59-
act_pre_scale = act_pre_scale.to("cuda")
60-
model[0].weight.act_pre_scale = act_pre_scale
55+
if act_pre_scale:
56+
model[0].weight.act_pre_scale = torch.ones(
57+
(1), dtype=torch.bfloat16, device="cuda"
58+
)
6159
example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),)
6260
ref_output = model(*example_inputs)
6361

0 commit comments

Comments
 (0)