Skip to content

Commit e3a6d23

Browse files
committed
refactoring functions for huggingface integration
1 parent f35ae41 commit e3a6d23

File tree

2 files changed

+109
-45
lines changed

2 files changed

+109
-45
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,18 @@
1+
import json
12
import tempfile
23
import unittest
34

45
import torch
6+
from safetensors.torch import load_file, save_file
57
from torch.testing._internal.common_utils import (
68
TestCase,
79
run_tests,
810
)
911

1012
from torchao import quantize_
1113
from torchao.prototype.safetensors.safetensors_support import (
12-
load_tensor_state_dict,
13-
save_tensor_state_dict,
14+
reconstruct_tensor_state_dict,
15+
convert_tensor_state_dict
1416
)
1517
from torchao.quantization.granularity import PerRow
1618
from torchao.quantization.quant_api import Float8DynamicActivationFloat8WeightConfig
@@ -19,6 +21,18 @@
1921
)
2022

2123

24+
def load_data(file_path: str, device: str):
25+
loaded_tensors = load_file(file_path, device)
26+
with open(file_path, "rb") as f:
27+
import struct
28+
29+
header_size = struct.unpack("<Q", f.read(8))[0]
30+
header_bytes = f.read(header_size)
31+
header = json.loads(header_bytes)
32+
metadata = header.get("__metadata__", {})
33+
return loaded_tensors, metadata
34+
35+
2236
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
2337
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
2438
class TestSafeTensors(TestCase):
@@ -32,8 +46,10 @@ def test_safetensors(self):
3246
ref_output = model(*example_inputs)
3347

3448
with tempfile.NamedTemporaryFile() as f:
35-
save_tensor_state_dict(model.state_dict(), f.name)
36-
reconstructed_dict = load_tensor_state_dict(f.name, device="cuda")
49+
tensors_dict, metadata = convert_tensor_state_dict(model.state_dict())
50+
save_file(tensors_dict, f.name, metadata=metadata)
51+
tensors_dict, metadata = load_data(file_path=f.name, device="cuda")
52+
reconstructed_dict = reconstruct_tensor_state_dict(tensors_dict, metadata)
3753

3854
model = torch.nn.Sequential(
3955
torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")
Lines changed: 89 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import json
22
import logging
3-
from typing import Dict
3+
from typing import Any, Dict
44

55
import torch
6-
from safetensors.torch import load_file, save_file
76

87
from torchao.prototype.safetensors.safetensors_serialization import (
98
Float8TensorAttributeJSONEncoder,
@@ -14,17 +13,51 @@
1413
logger: logging.Logger = logging.getLogger(__name__)
1514

1615

17-
def load_tensor_state_dict(file_path: str, device: str):
16+
def reconstruct_tensor_state_dict(
17+
tensors_data: Dict[str, Any],
18+
metadata: Dict[str, Any],
19+
):
1820
"""
19-
Load a dictionary of tensor subclasses from a safetensors file.
21+
Recover tensor subclass state dict from provided torch.Tensor data and metadata
22+
This function is used after loading in previously saved model state dict (using safetensors.save_file) to reconstruct tensor subclass structure
23+
24+
For example, given a previously converted tensors_data and metadata:
25+
tensors_data = {
26+
'0.weight:qdata': torch.Tensor(...),
27+
'0.weight:scale': torch.Tensor(...),
28+
'0.bias:_data': torch.Tensor(...),
29+
}
30+
metadata = {
31+
'0.weight': {
32+
'_type': 'Float8Tensor',
33+
'_data': {
34+
'block_size': [1,32],
35+
...
36+
}
37+
}
38+
'0.bias': {
39+
'_type': 'torch.Tensor',
40+
}
41+
'tensor_names': ['0.weight', '0.bias']
42+
}
43+
44+
We recover the structure of the original state dict:
45+
tensor_dict = {
46+
'0.weight': Float8Tensor(
47+
qdata=torch.Tensor(...),
48+
scale=torch.Tensor(...),
49+
block_size=[1,32],
50+
...),
51+
'0.bias': torch.Tensor(...),
52+
}
2053
2154
For torch.Tensors, we load:
2255
- _data: the tensor data
2356
- _type: the tensor type
2457
2558
For Float8Tensor, we load:
2659
- tensor_data: qdata and scale
27-
- tensor_attributes:
60+
- tensor_attributes (metadata):
2861
- block_size
2962
- mm_config
3063
- hp_value_lb
@@ -34,20 +67,13 @@ def load_tensor_state_dict(file_path: str, device: str):
3467
- dtype
3568
3669
Args:
37-
file_path: Path to the safetensors file
70+
tensors_data: Tensor data,
71+
metadata: Tensor attributes
3872
3973
Returns:
4074
Dictionary of reconstructed tensor subclasses
4175
"""
42-
loaded_tensors = load_file(file_path, device)
43-
44-
with open(file_path, "rb") as f:
45-
import struct
46-
47-
header_size = struct.unpack("<Q", f.read(8))[0]
48-
header_bytes = f.read(header_size)
49-
header = json.loads(header_bytes)
50-
metadata = header.get("__metadata__", {})
76+
combined_data = {**tensors_data, **metadata}
5177

5278
if "tensor_names" not in metadata:
5379
raise ValueError("No tensors found")
@@ -57,7 +83,7 @@ def load_tensor_state_dict(file_path: str, device: str):
5783

5884
for tensor_name in tensor_names:
5985
tensor_tensors = {}
60-
for key, value in loaded_tensors.items():
86+
for key, value in combined_data.items():
6187
if key.startswith(f"{tensor_name}:"):
6288
# Remove the prefix
6389
tensor_tensors[key[len(tensor_name) + 1 :]] = value
@@ -73,18 +99,45 @@ def load_tensor_state_dict(file_path: str, device: str):
7399
else:
74100
raise ValueError(f"Unsupported tensor type: {tensor_type}")
75101

76-
logger.info(
77-
f"Loaded {len(tensor_names)} tensor subclasses from {file_path} with metadata"
78-
)
79102
return result
80103

81104

82-
def save_tensor_state_dict(
83-
tensor_dict: Dict[str, Dict[str, torch.Tensor]],
84-
file_path: str,
105+
def convert_tensor_state_dict(
106+
tensors_dict: Dict[str, Dict[str, torch.Tensor]],
85107
):
86108
"""
87-
Save a dictionary of tensor subclasses with appropriate metadata.
109+
Convert a dictionary of tensor subclasses so that it is compatible with safetensors.save_file
110+
We disconstruct tensor subclass structure into torch.Tensor data and metadata
111+
112+
For example, given something like:
113+
tensor_dict = {
114+
'0.weight': Float8Tensor(
115+
qdata=torch.Tensor(...),
116+
scale=torch.Tensor(...),
117+
block_size=[1,32],
118+
...),
119+
'0.bias': torch.Tensor(...),
120+
}
121+
122+
We convert this to:
123+
tensors_data = {
124+
'0.weight:qdata': torch.Tensor(...),
125+
'0.weight:scale': torch.Tensor(...),
126+
'0.bias:_data': torch.Tensor(...),
127+
}
128+
metadata = {
129+
'0.weight': {
130+
'_type': 'Float8Tensor',
131+
'_data': {
132+
'block_size': [1,32],
133+
...
134+
}
135+
}
136+
'0.bias': {
137+
'_type': 'torch.Tensor',
138+
}
139+
'tensor_names': ['0.weight', '0.bias']
140+
}
88141
89142
For torch.Tensors, we save:
90143
- _data: the tensor data
@@ -105,22 +158,21 @@ def save_tensor_state_dict(
105158
106159
Args:
107160
tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
108-
file_path: Path where to save the tensors
109161
"""
110162

111-
combined_metadata = {}
112-
combined_tensors_dict = {}
163+
metadata = {}
164+
tensors_data = {}
113165

114-
for tensor_name, tensor in tensor_dict.items():
166+
for tensor_name, tensor in tensors_dict.items():
115167
if isinstance(tensor, Float8Tensor):
116-
tensors_dict = {}
168+
tensor_dict = {}
117169
for tensor_data_name in tensor.tensor_data_names:
118-
tensors_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
170+
tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
119171

120-
metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
172+
tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
121173
elif type(tensor) is torch.Tensor:
122-
tensors_dict = {"_data": tensor}
123-
metadata = json.dumps({"_type": torch.Tensor.__name__})
174+
tensor_dict = {"_data": tensor}
175+
tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
124176
else:
125177
raise ValueError(f"Unsupported tensor type: {type(tensor)}")
126178

@@ -129,15 +181,11 @@ def save_tensor_state_dict(
129181
f"{tensor_name}:{key}": (
130182
value.detach().clone() if isinstance(value, torch.Tensor) else value
131183
)
132-
for key, value in tensors_dict.items()
184+
for key, value in tensor_dict.items()
133185
}
134186

135-
combined_metadata[tensor_name] = metadata
136-
combined_tensors_dict.update(prefixed_tensors_dict)
137-
138-
combined_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))
187+
metadata[tensor_name] = tensor_metadata
188+
tensors_data.update(prefixed_tensors_dict)
139189

140-
save_file(combined_tensors_dict, file_path, metadata=combined_metadata)
141-
logger.info(
142-
f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata"
143-
)
190+
metadata["tensor_names"] = json.dumps(list(tensors_dict.keys()))
191+
return tensors_data, metadata

0 commit comments

Comments
 (0)