Skip to content

Commit 60bb577

Browse files
committed
refactoring functions for huggingface integration
1 parent f35ae41 commit 60bb577

File tree

2 files changed

+117
-79
lines changed

2 files changed

+117
-79
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1+
import json
12
import tempfile
23
import unittest
34

45
import torch
6+
from safetensors.torch import load_file, save_file
57
from torch.testing._internal.common_utils import (
68
TestCase,
79
run_tests,
810
)
911

1012
from torchao import quantize_
1113
from torchao.prototype.safetensors.safetensors_support import (
12-
load_tensor_state_dict,
13-
save_tensor_state_dict,
14+
flatten_tensor_state_dict,
15+
unflatten_tensor_state_dict,
1416
)
1517
from torchao.quantization.granularity import PerRow
1618
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
@@ -19,6 +21,18 @@
1921
)
2022

2123

24+
def load_data(file_path: str, device: str):
    """Load tensors and torchao metadata from a safetensors file.

    The safetensors format starts with an 8-byte little-endian unsigned
    integer giving the size of a JSON header; tensor-subclass metadata
    written by torchao lives under the header's "__metadata__" key.

    Args:
        file_path: Path to the safetensors file
        device: Device to load the tensors onto (e.g. "cpu", "cuda")

    Returns:
        A tuple (loaded_tensors, metadata) where loaded_tensors maps tensor
        names to torch.Tensor and metadata is the "__metadata__" dict from
        the file header (empty dict if the key is absent).
    """
    import struct  # hoisted out of the with-block; kept local so the helper stays self-contained

    loaded_tensors = load_file(file_path, device)
    with open(file_path, "rb") as f:
        # First 8 bytes: little-endian uint64 header size, then the JSON header.
        header_size = struct.unpack("<Q", f.read(8))[0]
        header = json.loads(f.read(header_size))
    metadata = header.get("__metadata__", {})
    return loaded_tensors, metadata
34+
35+
2236
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
2337
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
2438
class TestSafeTensors(TestCase):
@@ -32,8 +46,10 @@ def test_safetensors(self):
3246
ref_output = model(*example_inputs)
3347

3448
with tempfile.NamedTemporaryFile() as f:
35-
save_tensor_state_dict(model.state_dict(), f.name)
36-
reconstructed_dict = load_tensor_state_dict(f.name, device="cuda")
49+
tensors_dict, metadata = flatten_tensor_state_dict(model.state_dict())
50+
save_file(tensors_dict, f.name, metadata=metadata)
51+
tensors_data_dict, metadata_dict = load_data(file_path=f.name, device="cuda")
52+
reconstructed_dict = unflatten_tensor_state_dict(tensors_data_dict, metadata_dict)
3753

3854
model = torch.nn.Sequential(
3955
torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
Lines changed: 97 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import json
22
import logging
3-
from typing import Dict
3+
from typing import Any, Dict
44

55
import torch
6-
from safetensors.torch import load_file, save_file
76

87
from torchao.prototype.safetensors.safetensors_serialization import (
98
Float8TensorAttributeJSONEncoder,
@@ -14,55 +13,67 @@
1413
logger: logging.Logger = logging.getLogger(__name__)
1514

1615

17-
def load_tensor_state_dict(file_path: str, device: str):
16+
def unflatten_tensor_state_dict(
17+
tensors_data_dict: Dict[str, Any],
18+
metadata_dict: Dict[str, Any],
19+
):
1820
"""
19-
Load a dictionary of tensor subclasses from a safetensors file.
20-
21-
For torch.Tensors, we load:
22-
- _data: the tensor data
23-
- _type: the tensor type
24-
25-
For Float8Tensor, we load:
26-
- tensor_data: qdata and scale
27-
- tensor_attributes:
28-
- block_size
29-
- mm_config
30-
- hp_value_lb
31-
- hp_value_ub
32-
- act_quant_kwargs
33-
- kernel_preference
34-
- dtype
21+
Reconstructs tensor subclass state dict from provided torch.Tensor data and metadata
22+
This function is used after loading a model state dict previously saved with safetensors.save_file, to reconstruct the tensor subclass structure
23+
24+
For example, given a previously flattened tensors_data_dict and metadata_dict:
25+
tensors_data_dict = {
26+
'0.weight:qdata': torch.Tensor(...),
27+
'0.weight:scale': torch.Tensor(...),
28+
'0.bias:_data': torch.Tensor(...),
29+
}
30+
metadata_dict = {
31+
'0.weight': {
32+
'_type': 'Float8Tensor',
33+
'_data': {
34+
'block_size': [1,32],
35+
...
36+
}
37+
},
38+
'0.bias': {
39+
'_type': 'torch.Tensor',
40+
},
41+
'tensor_names': ['0.weight', '0.bias']
42+
}
43+
44+
We recover the structure of the original state dict:
45+
tensor_dict = {
46+
'0.weight': Float8Tensor(
47+
qdata=torch.Tensor(...),
48+
scale=torch.Tensor(...),
49+
block_size=[1,32],
50+
...),
51+
'0.bias': torch.Tensor(...),
52+
}
3553
3654
Args:
37-
file_path: Path to the safetensors file
55+
tensors_data_dict: a dictionary from "tensor_name:tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance
56+
metadata_dict: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance
3857
3958
Returns:
4059
Dictionary of reconstructed tensor subclasses
4160
"""
42-
loaded_tensors = load_file(file_path, device)
43-
44-
with open(file_path, "rb") as f:
45-
import struct
46-
47-
header_size = struct.unpack("<Q", f.read(8))[0]
48-
header_bytes = f.read(header_size)
49-
header = json.loads(header_bytes)
50-
metadata = header.get("__metadata__", {})
61+
combined_data = {**tensors_data_dict, **metadata_dict}
5162

52-
if "tensor_names" not in metadata:
63+
if "tensor_names" not in metadata_dict:
5364
raise ValueError("No tensors found")
5465

55-
tensor_names = json.loads(metadata["tensor_names"])
66+
tensor_names = json.loads(metadata_dict["tensor_names"])
5667
result = {}
5768

5869
for tensor_name in tensor_names:
5970
tensor_tensors = {}
60-
for key, value in loaded_tensors.items():
71+
for key, value in combined_data.items():
6172
if key.startswith(f"{tensor_name}:"):
6273
# Remove the prefix
6374
tensor_tensors[key[len(tensor_name) + 1 :]] = value
6475

65-
tensor_metadata = json.loads(metadata.get(tensor_name))
76+
tensor_metadata = json.loads(metadata_dict.get(tensor_name))
6677
tensor_type = tensor_metadata.get("_type")
6778

6879
if tensor_type == Float8Tensor.__name__:
@@ -73,54 +84,69 @@ def load_tensor_state_dict(file_path: str, device: str):
7384
else:
7485
raise ValueError(f"Unsupported tensor type: {tensor_type}")
7586

76-
logger.info(
77-
f"Loaded {len(tensor_names)} tensor subclasses from {file_path} with metadata"
78-
)
7987
return result
8088

8189

82-
def save_tensor_state_dict(
83-
tensor_dict: Dict[str, Dict[str, torch.Tensor]],
84-
file_path: str,
90+
def flatten_tensor_state_dict(
91+
tensors_dict: Dict[str, Dict[str, torch.Tensor]],
8592
):
8693
"""
87-
Save a dictionary of tensor subclasses with appropriate metadata.
88-
89-
For torch.Tensors, we save:
90-
- _data: the tensor data
91-
- _type: the tensor type
92-
93-
For Float8Tensor, we save:
94-
- tensor_data:
95-
- qdata
96-
- scale
97-
- tensor_attributes:
98-
- block_size
99-
- mm_config
100-
- hp_value_lb
101-
- hp_value_ub
102-
- act_quant_kwargs
103-
- kernel_preference
104-
- dtype
94+
Flattens a dictionary of tensor subclasses so that it is compatible with safetensors.save_file
95+
We deconstruct the tensor subclass structure into torch.Tensor data and metadata
96+
97+
For example, given something like:
98+
tensors_dict = {
99+
'0.weight': Float8Tensor(
100+
qdata=torch.Tensor(...),
101+
scale=torch.Tensor(...),
102+
block_size=[1,32],
103+
...),
104+
'0.bias': torch.Tensor(...),
105+
}
106+
107+
We flatten this to:
108+
tensors_data = {
109+
'0.weight:qdata': torch.Tensor(...),
110+
'0.weight:scale': torch.Tensor(...),
111+
'0.bias:_data': torch.Tensor(...),
112+
}
113+
metadata = {
114+
'0.weight': {
115+
'_type': 'Float8Tensor',
116+
'_data': {
117+
'block_size': [1,32],
118+
...
119+
}
120+
},
121+
'0.bias': {
122+
'_type': 'torch.Tensor',
123+
},
124+
'tensor_names': ['0.weight', '0.bias']
125+
}
105126
106127
Args:
107128
tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
108-
file_path: Path where to save the tensors
129+
130+
Returns:
131+
A tuple of (tensors_data, metadata) where
132+
tensors_data: Dict[str, torch.Tensor] contains the tensor data
133+
metadata: Dict[str, str] contains accompanying metadata from tensor subclass
134+
This structure is compatible with safetensors.save_file
109135
"""
110136

111-
combined_metadata = {}
112-
combined_tensors_dict = {}
137+
metadata = {}
138+
tensors_data = {}
113139

114-
for tensor_name, tensor in tensor_dict.items():
140+
for tensor_name, tensor in tensors_dict.items():
115141
if isinstance(tensor, Float8Tensor):
116-
tensors_dict = {}
142+
tensor_dict = {}
117143
for tensor_data_name in tensor.tensor_data_names:
118-
tensors_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
144+
tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
119145

120-
metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
146+
tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
121147
elif type(tensor) is torch.Tensor:
122-
tensors_dict = {"_data": tensor}
123-
metadata = json.dumps({"_type": torch.Tensor.__name__})
148+
tensor_dict = {"_data": tensor}
149+
tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
124150
else:
125151
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
126152

@@ -129,15 +155,11 @@ def save_tensor_state_dict(
129155
f"{tensor_name}:{key}": (
130156
value.detach().clone() if isinstance(value, torch.Tensor) else value
131157
)
132-
for key, value in tensors_dict.items()
158+
for key, value in tensor_dict.items()
133159
}
134160

135-
combined_metadata[tensor_name] = metadata
136-
combined_tensors_dict.update(prefixed_tensors_dict)
137-
138-
combined_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))
161+
metadata[tensor_name] = tensor_metadata
162+
tensors_data.update(prefixed_tensors_dict)
139163

140-
save_file(combined_tensors_dict, file_path, metadata=combined_metadata)
141-
logger.info(
142-
f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata"
143-
)
164+
metadata["tensor_names"] = json.dumps(list(tensors_dict.keys()))
165+
return tensors_data, metadata

0 commit comments

Comments
 (0)