Commit f71afd9

add int4tensor support for safetensors

ghstack-source-id: 96e3e23
Pull Request resolved: #3056

1 parent f92b898 · commit f71afd9

File tree: 3 files changed, +53 −22 lines
  test/prototype/safetensors/test_safetensors_support.py
  torchao/prototype/safetensors/safetensors_support.py
  torchao/prototype/safetensors/safetensors_utils.py
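For context, a minimal sketch of the round-trip this commit enables. It assumes the prototype helpers flatten_tensor_state_dict / unflatten_tensor_state_dict keep the shapes exercised by the test below (flatten returns plain tensors plus string metadata; unflatten takes them back), and that Int4WeightOnlyConfig produces the Int4Tensor subclass on this build; it is not the exact test code.

# Hedged sketch: save and reload an int4-quantized model through safetensors
# using the prototype flatten/unflatten helpers.
import torch
from safetensors import safe_open
from safetensors.torch import save_file

from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int4WeightOnlyConfig

model = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
quantize_(model, Int4WeightOnlyConfig())  # weights become Int4Tensor (assumed on this build)

# Flatten tensor subclasses into plain tensors plus JSON metadata, then save.
tensors, metadata = flatten_tensor_state_dict(model.state_dict())
save_file(tensors, "model.safetensors", metadata=metadata)

# Reload the plain tensors and the metadata, rebuild the subclasses, restore the model.
with safe_open("model.safetensors", framework="pt", device="cuda") as f:
    loaded = {k: f.get_tensor(k) for k in f.keys()}
    loaded_metadata = f.metadata()
state_dict = unflatten_tensor_state_dict(loaded, loaded_metadata)

fresh = torch.nn.Sequential(
    torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
)
fresh.load_state_dict(state_dict, assign=True)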

test/prototype/safetensors/test_safetensors_support.py
Lines changed: 26 additions & 8 deletions

@@ -6,6 +6,8 @@
 from safetensors.torch import load_file, save_file
 from torch.testing._internal.common_utils import (
     TestCase,
+    instantiate_parametrized_tests,
+    parametrize,
     run_tests,
 )

@@ -15,10 +17,11 @@
     unflatten_tensor_state_dict,
 )
 from torchao.quantization.granularity import PerRow
-from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
-from torchao.utils import (
-    is_sm_at_least_89,
+from torchao.quantization.quant_api import (
+    Float8DynamicActivationFloat8WeightConfig,
+    Int4WeightOnlyConfig,
 )
+from torchao.utils import is_sm_at_least_89


 def load_data(file_path: str, device: str):
@@ -36,13 +39,26 @@ def load_data(file_path: str, device: str):
 @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
 @unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
 class TestSafeTensors(TestCase):
-    def test_safetensors(self):
-        config = Float8DynamicActivationFloat8WeightConfig(granularity=PerRow())
+    @parametrize(
+        "config, act_pre_scale",
+        [
+            (Float8DynamicActivationFloat8WeightConfig(granularity=PerRow()), None),
+            (Int4WeightOnlyConfig(), None),
+            (
+                Int4WeightOnlyConfig(),
+                torch.ones((1), dtype=torch.bfloat16),
+            ),
+        ],
+    )
+    def test_safetensors(self, config, act_pre_scale=None):
         model = torch.nn.Sequential(
-            torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
+            torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         )
         quantize_(model, config)
-        example_inputs = (torch.randn(2, 32, dtype=torch.bfloat16, device="cuda"),)
+        if act_pre_scale is not None:
+            act_pre_scale = act_pre_scale.to("cuda")
+            model[0].weight.act_pre_scale = act_pre_scale
+        example_inputs = (torch.randn(2, 128, dtype=torch.bfloat16, device="cuda"),)
         ref_output = model(*example_inputs)

         with tempfile.NamedTemporaryFile() as f:
@@ -54,12 +70,14 @@ def test_safetensors(self):
             )

         model = torch.nn.Sequential(
-            torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
+            torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
         )
         model.load_state_dict(reconstructed_dict, assign=True)
         output = model(*example_inputs)
         assert torch.equal(output, ref_output)


+instantiate_parametrized_tests(TestSafeTensors)
+
 if __name__ == "__main__":
     run_tests()
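The test reads both the saved tensors and the torchao metadata back through a load_data helper that this commit leaves unchanged. A plausible sketch of what such a helper does, based only on the public safetensors API; the actual helper in the repo may differ.

# Hypothetical reconstruction of the unchanged load_data helper used above.
from safetensors import safe_open
from safetensors.torch import load_file


def load_data(file_path: str, device: str):
    # Plain tensors come from load_file; the torchao metadata travels in the
    # safetensors header and is read back via safe_open.
    tensors = load_file(file_path, device=device)
    with safe_open(file_path, framework="pt", device=device) as f:
        metadata = f.metadata()
    return tensors, metadata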

torchao/prototype/safetensors/safetensors_support.py
Lines changed: 13 additions & 7 deletions

@@ -5,10 +5,10 @@
 import torch

 from torchao.prototype.safetensors.safetensors_utils import (
-    Float8TensorAttributeJSONEncoder,
+    ALLOWED_TENSORS_SUBCLASSES,
+    TensorSubclassAttributeJSONEncoder,
     object_from_dict,
 )
-from torchao.quantization import Float8Tensor

 logger: logging.Logger = logging.getLogger(__name__)

@@ -77,7 +77,7 @@ def unflatten_tensor_state_dict(
         tensor_metadata = json.loads(metadata.get(tensor_name))
         tensor_type = tensor_metadata.get("_type")

-        if tensor_type == Float8Tensor.__name__:
+        if tensor_type in ALLOWED_TENSORS_SUBCLASSES:
             tensor_metadata["_data"].update(tensor_tensors)
             result[tensor_name] = object_from_dict(tensor_metadata)
         elif tensor_type == torch.Tensor.__name__:
@@ -140,12 +140,18 @@ def flatten_tensor_state_dict(
     tensors_data_dict = {}

     for tensor_name, tensor in tensors_dict.items():
-        if isinstance(tensor, Float8Tensor):
+        if tensor.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES:
             tensor_dict = {}
-            for tensor_data_name in tensor.tensor_data_names:
-                tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)

-            tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
+            all_tensor_data = list(tensor.tensor_data_names)  # create a copy
+            if hasattr(tensor, "optional_tensor_data_names"):
+                all_tensor_data += tensor.optional_tensor_data_names
+
+            for tensor_data_name in all_tensor_data:
+                if getattr(tensor, tensor_data_name) is not None:
+                    tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
+
+            tensor_metadata = json.dumps(tensor, cls=TensorSubclassAttributeJSONEncoder)
         elif type(tensor) is torch.Tensor:
             tensor_dict = {"_data": tensor}
             tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
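The rewritten flatten_tensor_state_dict no longer special-cases Float8Tensor: it only relies on a subclass exposing tensor_data_names and, optionally, optional_tensor_data_names, and it skips optional components that are None. A toy illustration of that protocol; FakeInt4Tensor and its attribute names are made up for the example and are not torchao classes.

# Illustrative only: a fake subclass exposing the attribute-name lists the
# flattening code iterates over. Real torchao subclasses define these themselves.
import torch


class FakeInt4Tensor:
    tensor_data_names = ["qdata", "scale"]
    optional_tensor_data_names = ["act_pre_scale"]

    def __init__(self, qdata, scale, act_pre_scale=None):
        self.qdata = qdata
        self.scale = scale
        self.act_pre_scale = act_pre_scale


t = FakeInt4Tensor(torch.zeros(4, dtype=torch.int8), torch.ones(1))

all_tensor_data = list(t.tensor_data_names)  # required components
if hasattr(t, "optional_tensor_data_names"):
    all_tensor_data += t.optional_tensor_data_names  # optional components, may be None

components = {
    name: getattr(t, name)
    for name in all_tensor_data
    if getattr(t, name) is not None
}
print(sorted(components))  # ['qdata', 'scale'] -- act_pre_scale skipped while unset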

torchao/prototype/safetensors/safetensors_utils.py
Lines changed: 14 additions & 7 deletions

@@ -6,34 +6,41 @@
 import torch

 import torchao
-from torchao.quantization import Float8Tensor
+from torchao.quantization import Float8Tensor, Int4Tensor
 from torchao.quantization.quantize_.common import KernelPreference
 from torchao.quantization.quantize_.workflows import QuantizeTensorToFloat8Kwargs

 ALLOWED_CLASSES = {
     "Float8Tensor": Float8Tensor,
+    "Int4Tensor": Int4Tensor,
     "Float8MMConfig": torchao.float8.inference.Float8MMConfig,
     "QuantizeTensorToFloat8Kwargs": QuantizeTensorToFloat8Kwargs,
     "PerRow": torchao.quantization.PerRow,
     "PerTensor": torchao.quantization.PerTensor,
     "KernelPreference": KernelPreference,
 }

-ALLOWED_TENSORS = ["Float8Tensor", "Tensor"]
+ALLOWED_TENSORS_SUBCLASSES = ["Float8Tensor", "Int4Tensor"]

 __all__ = [
-    "Float8TensorAttributeJSONEncoder",
+    "TensorSubclassAttributeJSONEncoder",
     "object_from_dict",
     "is_metadata_torchao",
 ]


-class Float8TensorAttributeJSONEncoder(json.JSONEncoder):
+class TensorSubclassAttributeJSONEncoder(json.JSONEncoder):
     def default(self, o):
-        if isinstance(o, Float8Tensor):
+        if o.__class__.__name__ in ALLOWED_TENSORS_SUBCLASSES:
             tensor_attr_dict = {}
+            optional_tensor_attributes = (
+                o.optional_tensor_attribute_names
+                if hasattr(o, "optional_tensor_attribute_names")
+                else []
+            )
+
             all_tensor_attributes = (
-                o.optional_tensor_attribute_names + o.tensor_attribute_names
+                optional_tensor_attributes + o.tensor_attribute_names
             )

             for tensor_attribute_name in all_tensor_attributes:
@@ -190,7 +197,7 @@ def is_metadata_torchao(metadata: Dict[str, Any]):

         # returns None if _type not in tensor_dict
         tensor_type = tensor_dict.get("_type")
-        if tensor_type not in ALLOWED_TENSORS:
+        if tensor_type not in ALLOWED_TENSORS_SUBCLASSES and tensor_type != "Tensor":
             return False

     return True
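The renamed encoder serializes whichever non-tensor attributes a supported subclass declares, and now tolerates subclasses that do not define optional_tensor_attribute_names. A hedged usage sketch, mirroring the json.dumps call in flatten_tensor_state_dict; the keys in the output depend on the subclass and torchao version, and it is assumed that Int4WeightOnlyConfig yields an Int4Tensor weight on this build.

# Sketch: dump an Int4Tensor weight's non-tensor attributes to a JSON metadata string.
import json

import torch

from torchao.prototype.safetensors.safetensors_utils import (
    TensorSubclassAttributeJSONEncoder,
)
from torchao.quantization import quantize_
from torchao.quantization.quant_api import Int4WeightOnlyConfig

linear = torch.nn.Linear(128, 256, dtype=torch.bfloat16, device="cuda")
quantize_(linear, Int4WeightOnlyConfig())  # weight becomes an Int4Tensor (assumed)

meta_json = json.dumps(linear.weight, cls=TensorSubclassAttributeJSONEncoder)
print(json.loads(meta_json))  # attribute metadata as declared by the subclass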
