import json
import logging
-from typing import Dict
+from typing import Any, Dict

import torch
-from safetensors.torch import load_file, save_file

from torchao.prototype.safetensors.safetensors_serialization import (
    Float8TensorAttributeJSONEncoder,
logger: logging.Logger = logging.getLogger(__name__)


-def load_tensor_state_dict(file_path: str, device: str):
+def reconstruct_tensor_state_dict(
+    tensors_data: Dict[str, Any],
+    metadata: Dict[str, Any],
+):
    """
-    Load a dictionary of tensor subclasses from a safetensors file.
+    Recover a tensor subclass state dict from the provided torch.Tensor data and metadata.
+    This function is used after loading a model state dict that was previously saved with safetensors.save_file, to reconstruct the tensor subclass structure.
+
+    For example, given a previously converted tensors_data and metadata:
+    tensors_data = {
+        '0.weight:qdata': torch.Tensor(...),
+        '0.weight:scale': torch.Tensor(...),
+        '0.bias:_data': torch.Tensor(...),
+    }
+    metadata = {
+        '0.weight': {
+            '_type': 'Float8Tensor',
+            '_data': {
+                'block_size': [1,32],
+                ...
+            }
+        }
+        '0.bias': {
+            '_type': 'torch.Tensor',
+        }
+        'tensor_names': ['0.weight', '0.bias']
+    }
+
+    We recover the structure of the original state dict:
+    tensor_dict = {
+        '0.weight': Float8Tensor(
+            qdata=torch.Tensor(...),
+            scale=torch.Tensor(...),
+            block_size=[1,32],
+            ...),
+        '0.bias': torch.Tensor(...),
+    }

    For torch.Tensors, we load:
    - _data: the tensor data
    - _type: the tensor type

    For Float8Tensor, we load:
    - tensor_data: qdata and scale
-    - tensor_attributes:
+    - tensor_attributes (metadata):
        - block_size
        - mm_config
        - hp_value_lb
@@ -34,20 +67,13 @@ def load_tensor_state_dict(file_path: str, device: str):
        - dtype

    Args:
-        file_path: Path to the safetensors file
+        tensors_data: Tensor data
+        metadata: Tensor attributes

    Returns:
        Dictionary of reconstructed tensor subclasses
    """
-    loaded_tensors = load_file(file_path, device)
-
-    with open(file_path, "rb") as f:
-        import struct
-
-        header_size = struct.unpack("<Q", f.read(8))[0]
-        header_bytes = f.read(header_size)
-        header = json.loads(header_bytes)
-        metadata = header.get("__metadata__", {})
+    combined_data = {**tensors_data, **metadata}

    if "tensor_names" not in metadata:
        raise ValueError("No tensors found")
@@ -57,7 +83,7 @@ def load_tensor_state_dict(file_path: str, device: str):

    for tensor_name in tensor_names:
        tensor_tensors = {}
-        for key, value in loaded_tensors.items():
+        for key, value in combined_data.items():
            if key.startswith(f"{tensor_name}:"):
                # Remove the prefix
                tensor_tensors[key[len(tensor_name) + 1 :]] = value
@@ -73,18 +99,45 @@ def load_tensor_state_dict(file_path: str, device: str):
        else:
            raise ValueError(f"Unsupported tensor type: {tensor_type}")

-    logger.info(
-        f"Loaded {len(tensor_names)} tensor subclasses from {file_path} with metadata"
-    )
    return result


-def save_tensor_state_dict(
-    tensor_dict: Dict[str, Dict[str, torch.Tensor]],
-    file_path: str,
+def convert_tensor_state_dict(
+    tensors_dict: Dict[str, Dict[str, torch.Tensor]],
):
    """
-    Save a dictionary of tensor subclasses with appropriate metadata.
+    Convert a dictionary of tensor subclasses so that it is compatible with safetensors.save_file.
+    We deconstruct the tensor subclass structure into torch.Tensor data and metadata.
+
+    For example, given something like:
+    tensor_dict = {
+        '0.weight': Float8Tensor(
+            qdata=torch.Tensor(...),
+            scale=torch.Tensor(...),
+            block_size=[1,32],
+            ...),
+        '0.bias': torch.Tensor(...),
+    }
+
+    We convert this to:
+    tensors_data = {
+        '0.weight:qdata': torch.Tensor(...),
+        '0.weight:scale': torch.Tensor(...),
+        '0.bias:_data': torch.Tensor(...),
+    }
+    metadata = {
+        '0.weight': {
+            '_type': 'Float8Tensor',
+            '_data': {
+                'block_size': [1,32],
+                ...
+            }
+        }
+        '0.bias': {
+            '_type': 'torch.Tensor',
+        }
+        'tensor_names': ['0.weight', '0.bias']
+    }

    For torch.Tensors, we save:
    - _data: the tensor data
@@ -105,22 +158,21 @@ def save_tensor_state_dict(

    Args:
        tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
-        file_path: Path where to save the tensors
    """

-    combined_metadata = {}
-    combined_tensors_dict = {}
+    metadata = {}
+    tensors_data = {}

-    for tensor_name, tensor in tensor_dict.items():
+    for tensor_name, tensor in tensors_dict.items():
        if isinstance(tensor, Float8Tensor):
-            tensors_dict = {}
+            tensor_dict = {}
            for tensor_data_name in tensor.tensor_data_names:
-                tensors_dict[tensor_data_name] = getattr(tensor, tensor_data_name)
+                tensor_dict[tensor_data_name] = getattr(tensor, tensor_data_name)

-            metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
+            tensor_metadata = json.dumps(tensor, cls=Float8TensorAttributeJSONEncoder)
        elif type(tensor) is torch.Tensor:
-            tensors_dict = {"_data": tensor}
-            metadata = json.dumps({"_type": torch.Tensor.__name__})
+            tensor_dict = {"_data": tensor}
+            tensor_metadata = json.dumps({"_type": torch.Tensor.__name__})
        else:
            raise ValueError(f"Unsupported tensor type: {type(tensor)}")

@@ -129,15 +181,11 @@ def save_tensor_state_dict(
            f"{tensor_name}:{key}": (
                value.detach().clone() if isinstance(value, torch.Tensor) else value
            )
-            for key, value in tensors_dict.items()
+            for key, value in tensor_dict.items()
        }

-        combined_metadata[tensor_name] = metadata
-        combined_tensors_dict.update(prefixed_tensors_dict)
-
-    combined_metadata["tensor_names"] = json.dumps(list(tensor_dict.keys()))
+        metadata[tensor_name] = tensor_metadata
+        tensors_data.update(prefixed_tensors_dict)

-    save_file(combined_tensors_dict, file_path, metadata=combined_metadata)
-    logger.info(
-        f"Saved {len(tensor_dict)} tensor subclasses to {file_path} with metadata"
-    )
+    metadata["tensor_names"] = json.dumps(list(tensors_dict.keys()))
+    return tensors_data, metadata
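
Taken together, the two new helpers replace the old file-based save/load pair: convert_tensor_state_dict produces the (tensors_data, metadata) pair that safetensors.torch.save_file accepts, and reconstruct_tensor_state_dict rebuilds the subclass state dict after loading. Below is a minimal usage sketch, not part of this diff; it assumes the two helpers are in scope (imported from the module this diff modifies), uses a hypothetical model and file name, and reads the file-level metadata back via safetensors.safe_open.

# Usage sketch only; `model` and the file path are hypothetical.
from safetensors import safe_open
from safetensors.torch import load_file, save_file

# Saving: flatten tensor subclasses into plain tensors plus string metadata,
# then pass both to safetensors.
tensors_data, metadata = convert_tensor_state_dict(model.state_dict())
save_file(tensors_data, "model.safetensors", metadata=metadata)

# Loading: read the plain tensors and the file-level metadata back,
# then rebuild the original tensor subclass state dict.
loaded_tensors = load_file("model.safetensors")
with safe_open("model.safetensors", framework="pt") as f:
    loaded_metadata = f.metadata()
state_dict = reconstruct_tensor_state_dict(loaded_tensors, loaded_metadata)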