Skip to content

Commit cb0854f

Browse files
committed
refactoring functions for huggingface integration
1 parent f35ae41 commit cb0854f

File tree

2 files changed

+28
-26
lines changed

2 files changed

+28
-26
lines changed

test/prototype/safetensors/test_safetensors_support.py

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
1+
import json
12
import tempfile
23
import unittest
34

45
import torch
6+
from safetensors.torch import load_file, save_file
57
from torch.testing._internal.common_utils import (
68
TestCase,
79
run_tests,
@@ -19,6 +21,18 @@
1921
)
2022

2123

24+
def load_data(file_path: str, device: str):
25+
loaded_tensors = load_file(file_path, device)
26+
with open(file_path, "rb") as f:
27+
import struct
28+
29+
header_size = struct.unpack("<Q", f.read(8))[0]
30+
header_bytes = f.read(header_size)
31+
header = json.loads(header_bytes)
32+
metadata = header.get("__metadata__", {})
33+
return loaded_tensors, metadata
34+
35+
2236
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
2337
@unittest.skipIf(not is_sm_at_least_89(), "Need sm89+")
2438
class TestSafeTensors(TestCase):
@@ -32,8 +46,10 @@ def test_safetensors(self):
3246
ref_output = model(*example_inputs)
3347

3448
with tempfile.NamedTemporaryFile() as f:
35-
save_tensor_state_dict(model.state_dict(), f.name)
36-
reconstructed_dict = load_tensor_state_dict(f.name, device="cuda")
49+
tensors_dict, metadata = save_tensor_state_dict(model.state_dict())
50+
save_file(tensors_dict, f.name, metadata=metadata)
51+
tensor_dict, metadata = load_data(file_path=f.name, device="cuda")
52+
reconstructed_dict = load_tensor_state_dict(tensor_dict, metadata)
3753

3854
model = torch.nn.Sequential(
3955
torch.nn.Linear(32, 256, dtype=torch.bfloat16, device="cuda")

torchao/prototype/safetensors/safetensors_support.py

Lines changed: 10 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
import json
22
import logging
3-
from typing import Dict
3+
from typing import Any, Dict
44

55
import torch
6-
from safetensors.torch import load_file, save_file
76

87
from torchao.prototype.safetensors.safetensors_serialization import (
98
Float8TensorAttributeJSONEncoder,
@@ -14,7 +13,10 @@
1413
logger: logging.Logger = logging.getLogger(__name__)
1514

1615

17-
def load_tensor_state_dict(file_path: str, device: str):
16+
def load_tensor_state_dict(
17+
tensor_data: Dict[str, Any],
18+
metadata: Dict[str, Any],
19+
):
1820
"""
1921
Load a dictionary of tensor subclasses from a safetensors file.
2022
@@ -34,20 +36,13 @@ def load_tensor_state_dict(file_path: str, device: str):
3436
- dtype
3537
3638
Args:
37-
file_path: Path to the safetensors file
39+
tensor_data: Tensor data,
40+
metadata: Tensor attributes,
3841
3942
Returns:
4043
Dictionary of reconstructed tensor subclasses
4144
"""
42-
loaded_tensors = load_file(file_path, device)
43-
44-
with open(file_path, "rb") as f:
45-
import struct
46-
47-
header_size = struct.unpack("<Q", f.read(8))[0]
48-
header_bytes = f.read(header_size)
49-
header = json.loads(header_bytes)
50-
metadata = header.get("__metadata__", {})
45+
combined_data = {**tensor_data, **metadata}
5146

5247
if "tensor_names" not in metadata:
5348
raise ValueError("No tensors found")
@@ -57,7 +52,7 @@ def load_tensor_state_dict(file_path: str, device: str):
5752

5853
for tensor_name in tensor_names:
5954
tensor_tensors = {}
60-
for key, value in loaded_tensors.items():
55+
for key, value in combined_data.items():
6156
if key.startswith(f"{tensor_name}:"):
6257
# Remove the prefix
6358
tensor_tensors[key[len(tensor_name) + 1 :]] = value
@@ -73,15 +68,11 @@ def load_tensor_state_dict(file_path: str, device: str):
7368
else:
7469
raise ValueError(f"Unsupported tensor type: {tensor_type}")
7570

76-
logger.info(
77-
f"Loaded {len(tensor_names)} tensor subclasses from {file_path} with metadata"
78-
)
7971
return result
8072

8173

8274
def save_tensor_state_dict(
8375
tensor_dict: Dict[str, Dict[str, torch.Tensor]],
84-
file_path: str,
8576
):
8677
"""
8778
Save a dictionary of tensor subclasses with appropriate metadata.
@@ -105,7 +96,6 @@ def save_tensor_state_dict(
10596
10697
Args:
10798
tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
108-
file_path: Path where to save the tensors
10999
"""
110100

111101
combined_metadata = {}
@@ -136,8 +126,4 @@ def save_tensor_state_dict(
136126
combined_tensors_dict.update(prefixed_tensors_dict)
137127

138128
combined_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))
139-
140-
save_file(combined_tensors_dict, file_path, metadata=combined_metadata)
141-
logger.info(
142-
f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata"
143-
)
129+
return combined_tensors_dict, combined_metadata

0 commit comments

Comments
 (0)