11import json
22import logging
3- from typing import Dict
3+ from typing import Any , Dict
44
55import torch
6- from safetensors .torch import load_file , save_file
76
87from torchao .prototype .safetensors .safetensors_serialization import (
98 Float8TensorAttributeJSONEncoder ,
1413logger : logging .Logger = logging .getLogger (__name__ )
1514
1615
17- def load_tensor_state_dict (file_path : str , device : str ):
16+ def load_tensor_state_dict (
17+ tensor_data : Dict [str , Any ],
18+ metadata : Dict [str , Any ],
19+ ):
1820 """
1921 Reconstruct a dictionary of tensor subclasses from the given tensor data and metadata.
2022
@@ -34,20 +36,13 @@ def load_tensor_state_dict(file_path: str, device: str):
3436 - dtype
3537
3638 Args:
37- file_path: Path to the safetensors file
39+ tensor_data: Tensor data, keyed as "<tensor_name>:<key>"
40+ metadata: Tensor attributes; must contain a "tensor_names" entry
3841
3942 Returns:
4043 Dictionary of reconstructed tensor subclasses
4144 """
42- loaded_tensors = load_file (file_path , device )
43-
44- with open (file_path , "rb" ) as f :
45- import struct
46-
47- header_size = struct .unpack ("<Q" , f .read (8 ))[0 ]
48- header_bytes = f .read (header_size )
49- header = json .loads (header_bytes )
50- metadata = header .get ("__metadata__" , {})
45+ combined_data = {** tensor_data , ** metadata }
5146
5247 if "tensor_names" not in metadata :
5348 raise ValueError ("No tensors found" )
@@ -57,7 +52,7 @@ def load_tensor_state_dict(file_path: str, device: str):
5752
5853 for tensor_name in tensor_names :
5954 tensor_tensors = {}
60- for key , value in loaded_tensors .items ():
55+ for key , value in combined_data .items ():
6156 if key .startswith (f"{ tensor_name } :" ):
6257 # Remove the prefix
6358 tensor_tensors [key [len (tensor_name ) + 1 :]] = value
@@ -73,15 +68,11 @@ def load_tensor_state_dict(file_path: str, device: str):
7368 else :
7469 raise ValueError (f"Unsupported tensor type: { tensor_type } " )
7570
76- logger .info (
77- f"Loaded { len (tensor_names )} tensor subclasses from { file_path } with metadata"
78- )
7971 return result
8072
8173
8274def save_tensor_state_dict (
8375 tensor_dict : Dict [str , Dict [str , torch .Tensor ]],
84- file_path : str ,
8576):
8677 """
8778 Combine a dictionary of tensor subclasses into one tensor dict plus metadata, and return both (nothing is written to disk).
@@ -105,7 +96,6 @@ def save_tensor_state_dict(
10596
10697 Args:
10798 tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
108- file_path: Path where to save the tensors
10999 """
110100
111101 combined_metadata = {}
@@ -136,8 +126,4 @@ def save_tensor_state_dict(
136126 combined_tensors_dict .update (prefixed_tensors_dict )
137127
138128 combined_metadata ["tensor_names" ] = json .dumps (list (tensor_dict .keys ()))
139-
140- save_file (combined_tensors_dict , file_path , metadata = combined_metadata )
141- logger .info (
142- f"Saved { len (tensor_dict )} tensor subclasses to { file_path } with metadata"
143- )
129+ return combined_tensors_dict , combined_metadata
0 commit comments