
Commit 63cb06a

add int4tensor support for safetensors

ghstack-source-id: cec24fc
Pull Request resolved: #3056

1 parent: f92b898

3 files changed: +25 -20 lines

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 9 additions & 6 deletions
@@ -7,6 +7,8 @@
 from torch.testing._internal.common_utils import (
     TestCase,
     run_tests,
+    instantiate_parametrized_tests,
+    parametrize,
 )
 
 from torchao import quantize_
@@ -15,7 +17,7 @@
     unflatten_tensor_state_dict,
 )
 from torchao.quantization.granularity import PerRow
-from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
+from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig, Int4WeightOnlyConfig
 from torchao.utils import (
     is_sm_at_least_89,
 )
@@ -36,13 +38,13 @@ def load_data(file_path: str, device: str):
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
 class TestSafeTensors(TestCase):
-    def test_safetensors(self):
-        config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+    @parametrize("config", [Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), Int4WeightOnlyConfig()])
+    def test_safetensors(self, config):
         model = torch.nn.Sequential(
-            torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
+            torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         )
         quantize_(model, config)
-        example_inputs = (torch.randn(2, 32, dtype=torch.bfloat16, device="cuda"),)
+        example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),)
         ref_output = model(*example_inputs)
 
         with tempfile.NamedTemporaryFile() as f:
@@ -54,12 +56,13 @@ def test_safetensors(self):
             )
 
         model = torch.nn.Sequential(
-            torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
+            torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         )
         model.load_state_dict(reconstructed_dict, assign=True)
         output = model(*example_inputs)
         assert torch.equal(output, ref_output)
 
+instantiate_parametrized_tests(TestSafeTensors)
 
 if __name__ == "__main__":
     run_tests()
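For orientation, a minimal sketch of the round trip this test exercises with the new Int4 path. The exact shape of flatten_tensor_state_dict's return value (assumed here to be a (tensors, metadata) pair, mirroring how the test consumes it) is inferred from the diff fragments above, not a confirmed signature.

import tempfile

import torch
from safetensors import safe_open
from safetensors.torch import save_file

from torchao import quantize_
from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)
from torchao.quantization.quant_api import Int4WeightOnlyConfig

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
quantize_(model, Int4WeightOnlyConfig())

# Flatten the quantized state dict into plain tensors plus JSON metadata,
# since safetensors itself only serializes plain tensors.
tensors, metadata = flatten_tensor_state_dict(model.state_dict())

with tempfile.NamedTemporaryFile() as f:
    save_file(tensors, f.name, metadata=metadata)
    with safe_open(f.name, framework="pt", device="cuda") as g:
        loaded = {k: g.get_tensor(k) for k in g.keys()}
        loaded_metadata = g.metadata()

# Rebuild the Int4Tensor (or Float8Tensor) weights from the flat tensors.
reconstructed = unflatten_tensor_state_dict(loaded, loaded_metadata)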

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 9 additions & 9 deletions
@@ -7,8 +7,9 @@
 from torchao.prototype.safetensors.safetensors_utils import (
     Float8TensorAttributeJSONEncoder,
     object_from_dict,
+    ALLOWED_TENSORS,
 )
-from torchao.quantization import Float8Tensor
+from torchao.quantization import Float8Tensor, Int4Tensor
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -76,12 +77,11 @@ def unflatten_tensor_state_dict(
 
         tensor_metadata = json.loads(metadata.get(tensor_name))
         tensor_type = tensor_metadata.get("_type")
-
-        if tensor_type == Float8Tensor.__name__:
+        if tensor_type == torch.Tensor.__name__:
+            result[tensor_name] = tensor_tensors["_data"]
+        elif tensor_type in ALLOWED_TENSORS:
             tensor_metadata["_data"].update(tensor_tensors)
             result[tensor_name] = object_from_dict(tensor_metadata)
-        elif tensor_type == torch.Tensor.__name__:
-            result[tensor_name] = tensor_tensors["_data"]
         else:
             raise ValueError(f"Unsupported tensor type: {tensor_type}")
 
@@ -140,15 +140,15 @@ def flatten_tensor_state_dict(
     tensors_data_dict = {}
 
     for tensor_name, tensor in tensors_dict.items():
-        if isinstance(tensor, Float8Tensor):
+        if type(tensor) is torch.Tensor:
+            tensor_dict = {"_data": tensor}
+            tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
+        elif tensor.__class__.__name__ in ALLOWED_TENSORS:
             tensor_dict = {}
             for tensor_data_name in tensor.tensor_data_names:
                 tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
 
             tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
-        elif type(tensor) is torch.Tensor:
-            tensor_dict = {"_data": tensor}
-            tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
         else:
             raise ValueError(f"Unsupported tensor type: {type(tensor)}")
 
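One detail worth noting in the reordered flatten_tensor_state_dict branches: the plain-tensor case uses an exact type() check rather than isinstance(), because torchao's quantized tensors are themselves torch.Tensor subclasses and an isinstance() test would swallow them before the ALLOWED_TENSORS branch could run. A small illustration:

import torch
from torchao.quantization import Float8Tensor

t = torch.randn(4)
assert type(t) is torch.Tensor  # plain tensor takes the "_data" fast path

# Float8Tensor (and Int4Tensor) subclass torch.Tensor, so they are
# dispatched by class name through ALLOWED_TENSORS instead.
assert issubclass(Float8Tensor, torch.Tensor)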

torchao/prototype/safetensors/safetensors_utils.py

Lines changed: 7 additions & 5 deletions
@@ -6,34 +6,36 @@
 import torch
 
 import torchao
-from torchao.quantization import Float8Tensor
+from torchao.quantization import Float8Tensor, Int4Tensor
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs
 
 ALLOWED_CLASSES = {
     "Float8Tensor": Float8Tensor,
+    "Int4Tensor": Int4Tensor,
     "Float8MMConfig": torchao.float8.inference.Float8MMConfig,
     "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
     "PerRow": torchao.quantization.PerRow,
     "PerTensor": torchao.quantization.PerTensor,
     "KernelPreference": KernelPreference,
 }
 
-ALLOWED_TENSORS = ["Float8Tensor", "Tensor"]
+ALLOWED_TENSORS = ["Float8Tensor", "Int4Tensor", "Tensor"]
 
 __all__ = [
     "Float8TensorAttributeJSONEncoder",
     "object_from_dict",
     "is_metadata_torchao",
 ]
 
-
 class Float8TensorAttributeJSONEncoder(json.JSONEncoder):
     def default(self, o):
-        if isinstance(o, Float8Tensor):
+        if o.__class__.__name__ in ALLOWED_TENSORS:
             tensor_attr_dict = {}
+            optional_tensor_attributes = o.optional_tensor_attribute_names if hasattr(o, "optional_tensor_attribute_names") else []
+
             all_tensor_attributes = (
-                o.optional_tensor_attribute_names + o.tensor_attribute_names
+                optional_tensor_attributes + o.tensor_attribute_names
             )
 
             for tensor_attribute_name in all_tensor_attributes:
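The hasattr guard added to the encoder covers tensor classes that do not declare optional_tensor_attribute_names, falling back to an empty list. A minimal sketch of that fallback pattern, using a made-up stand-in class (TensorLike is hypothetical, not a torchao type):

class TensorLike:  # hypothetical stand-in for a tensor subclass
    tensor_attribute_names = ["block_size"]
    block_size = [1, 128]

obj = TensorLike()
# Same guard as in the encoder above: [] when the attribute is absent.
optional = obj.optional_tensor_attribute_names if hasattr(obj, "optional_tensor_attribute_names") else []
assert optional + obj.tensor_attribute_names == ["block_size"]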
