Skip to content

Commit fd06151

Browse files
authored
add Int4TilePackedTo4dTensor subclass to safetensors (#3064)
1 parent d23ed9e commit fd06151

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class TestSafeTensors(TestCase):
4545
(Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), False),
4646
(Int4WeightOnlyConfig(), False),
4747
(Int4WeightOnlyConfig(), True),
48+
(Int4WeightOnlyConfig(int4_packing_format="tile_packed_to_4d"), False),
4849
],
4950
)
5051
def test_safetensors(self, config, act_pre_scale=False):

torchao/prototype/safetensors/safetensors_utils.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,30 @@
66
import torch
77

88
import torchao
9-
from torchao.quantization import Float8Tensor, Int4Tensor
9+
from torchao.quantization import (
10+
Float8Tensor,
11+
Int4Tensor,
12+
Int4TilePackedTo4dTensor,
13+
)
1014
from torchao.quantization.quantize_.common import KernelPreference
1115
from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs
1216

1317
ALLOWED_CLASSES = {
1418
"Float8Tensor": Float8Tensor,
1519
"Int4Tensor": Int4Tensor,
20+
"Int4TilePackedTo4dTensor": Int4TilePackedTo4dTensor,
1621
"Float8MMConfig": torchao.float8.inference.Float8MMConfig,
1722
"QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
1823
"PerRow": torchao.quantization.PerRow,
1924
"PerTensor": torchao.quantization.PerTensor,
2025
"KernelPreference": KernelPreference,
2126
}
2227

23-
ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor"]
28+
ALLOWED_TENSORS_SUBCLASSES = [
29+
"Float8Tensor",
30+
"Int4Tensor",
31+
"Int4TilePackedTo4dTensor",
32+
]
2433

2534
__all__ = [
2635
"TensorSubclassAttributeJSONEncoder",

0 commit comments

Comments (0)