11import json
22import logging
3- from typing import Dict
3+ from typing import Any , Dict
44
55import torch
6- from safetensors .torch import load_file , save_file
76
87from torchao .prototype .safetensors .safetensors_serialization import (
98 Float8TensorAttributeJSONEncoder ,
1413logger : logging .Logger = logging .getLogger (__name__ )
1514
1615
17- def load_tensor_state_dict (file_path : str , device : str ):
16+ def unflatten_tensor_state_dict (
17+ tensors_data_dict : Dict [str , Any ],
18+ metadata_dict : Dict [str , Any ],
19+ ):
1820 """
19- Load a dictionary of tensor subclasses from a safetensors file.
20-
21- For torch.Tensors, we load:
22- - _data: the tensor data
23- - _type: the tensor type
24-
25- For Float8Tensor, we load:
26- - tensor_data: qdata and scale
27- - tensor_attributes:
28- - block_size
29- - mm_config
30- - hp_value_lb
31- - hp_value_ub
32- - act_quant_kwargs
33- - kernel_preference
34- - dtype
21+ Reconstructs tensor subclass state dict from provided torch.Tensor data and metadata
 22+ This function is used after loading a model state dict (previously saved with safetensors.save_file) to reconstruct the tensor subclass structure
23+
24+ For example, given a previously flattened tensors_data_dict and metadata_dict:
25+ tensors_data_dict = {
26+ '0.weight:qdata': torch.Tensor(...),
27+ '0.weight:scale': torch.Tensor(...),
28+ '0.bias:_data': torch.Tensor(...),
29+ }
30+ metadata_dict = {
31+ '0.weight': {
32+ '_type': 'Float8Tensor',
33+ '_data': {
34+ 'block_size': [1,32],
35+ ...
36+ }
37+ }
38+ '0.bias': {
39+ '_type': 'torch.Tensor',
40+ }
41+ 'tensor_names': ['0.weight', '0.bias']
42+ }
43+
44+ We recover the structure of the original state dict:
45+ tensor_dict = {
46+ '0.weight': Float8Tensor(
47+ qdata=torch.Tensor(...),
48+ scale=torch.Tensor(...),
49+ block_size=[1,32],
50+ ...),
51+ '0.bias': torch.Tensor(...),
52+ }
3553
3654 Args:
37- file_path: Path to the safetensors file
55+ tensors_data_dict: a dictionary from "tensor_name:tensor_data_attribute_name" to flattened torch.Tensor data for tensor subclass instance
56+ metadata_dict: a dictionary from "tensor_name" to another dictionary that contains type and attributes for tensor subclass instance
3857
3958 Returns:
4059 Dictionary of reconstructed tensor subclasses
4160 """
42- loaded_tensors = load_file (file_path , device )
43-
44- with open (file_path , "rb" ) as f :
45- import struct
46-
47- header_size = struct .unpack ("<Q" , f .read (8 ))[0 ]
48- header_bytes = f .read (header_size )
49- header = json .loads (header_bytes )
50- metadata = header .get ("__metadata__" , {})
61+ combined_data = {** tensors_data_dict , ** metadata_dict }
5162
52- if "tensor_names" not in metadata :
63+ if "tensor_names" not in metadata_dict :
5364 raise ValueError ("No tensors found" )
5465
55- tensor_names = json .loads (metadata ["tensor_names" ])
66+ tensor_names = json .loads (metadata_dict ["tensor_names" ])
5667 result = {}
5768
5869 for tensor_name in tensor_names :
5970 tensor_tensors = {}
60- for key , value in loaded_tensors .items ():
71+ for key , value in combined_data .items ():
6172 if key .startswith (f"{ tensor_name } :" ):
6273 # Remove the prefix
6374 tensor_tensors [key [len (tensor_name ) + 1 :]] = value
6475
65- tensor_metadata = json .loads (metadata .get (tensor_name ))
76+ tensor_metadata = json .loads (metadata_dict .get (tensor_name ))
6677 tensor_type = tensor_metadata .get ("_type" )
6778
6879 if tensor_type == Float8Tensor .__name__ :
@@ -73,54 +84,69 @@ def load_tensor_state_dict(file_path: str, device: str):
7384 else :
7485 raise ValueError (f"Unsupported tensor type: { tensor_type } " )
7586
76- logger .info (
77- f"Loaded { len (tensor_names )} tensor subclasses from { file_path } with metadata"
78- )
7987 return result
8088
8189
82- def save_tensor_state_dict (
83- tensor_dict : Dict [str , Dict [str , torch .Tensor ]],
84- file_path : str ,
90+ def flatten_tensor_state_dict (
91+ tensors_dict : Dict [str , Dict [str , torch .Tensor ]],
8592):
8693 """
87- Save a dictionary of tensor subclasses with appropriate metadata.
88-
89- For torch.Tensors, we save:
90- - _data: the tensor data
91- - _type: the tensor type
92-
93- For Float8Tensor, we save:
94- - tensor_data:
95- - qdata
96- - scale
97- - tensor_attributes:
98- - block_size
99- - mm_config
100- - hp_value_lb
101- - hp_value_ub
102- - act_quant_kwargs
103- - kernel_preference
104- - dtype
94+ Flattens a dictionary of tensor subclasses so that it is compatible with safetensors.save_file
 95+ We deconstruct tensor subclass structure into torch.Tensor data and metadata
96+
97+ For example, given something like:
98+ tensor_dict = {
99+ '0.weight': Float8Tensor(
100+ qdata=torch.Tensor(...),
101+ scale=torch.Tensor(...),
102+ block_size=[1,32],
103+ ...),
104+ '0.bias': torch.Tensor(...),
105+ }
106+
107+ We flatten this to:
108+ tensors_data = {
109+ '0.weight:qdata': torch.Tensor(...),
110+ '0.weight:scale': torch.Tensor(...),
111+ '0.bias:_data': torch.Tensor(...),
112+ }
113+ metadata = {
114+ '0.weight': {
115+ '_type': 'Float8Tensor',
116+ '_data': {
117+ 'block_size': [1,32],
118+ ...
119+ }
120+ }
121+ '0.bias': {
122+ '_type': 'torch.Tensor',
123+ }
124+ 'tensor_names': ['0.weight', '0.bias']
125+ }
105126
106127 Args:
107128 tensor_dict: Dictionary of tensor subclasses to save, with keys as tensor names
108- file_path: Path where to save the tensors
129+
130+ Returns:
131+ A tuple of (tensors_data, metadata) where
132+ tensors_data: Dict[str, torch.Tensor] contains the tensor data
133+ metadata: Dict[str, str] contains accompanying metadata from tensor subclass
134+ This structure is compatible with safetensors.save_file
109135 """
110136
111- combined_metadata = {}
112- combined_tensors_dict = {}
137+ metadata = {}
138+ tensors_data = {}
113139
114- for tensor_name , tensor in tensor_dict .items ():
140+ for tensor_name , tensor in tensors_dict .items ():
115141 if isinstance (tensor , Float8Tensor ):
116- tensors_dict = {}
142+ tensor_dict = {}
117143 for tensor_data_name in tensor .tensor_data_names :
118- tensors_dict [tensor_data_name ] = getattr (tensor , tensor_data_name )
144+ tensor_dict [tensor_data_name ] = getattr (tensor , tensor_data_name )
119145
120- metadata = json .dumps (tensor , cls = Float8TensorAttributeJSONEncoder )
146+ tensor_metadata = json .dumps (tensor , cls = Float8TensorAttributeJSONEncoder )
121147 elif type (tensor ) is torch .Tensor :
122- tensors_dict = {"_data" : tensor }
123- metadata = json .dumps ({"_type" : torch .Tensor .__name__ })
148+ tensor_dict = {"_data" : tensor }
149+ tensor_metadata = json .dumps ({"_type" : torch .Tensor .__name__ })
124150 else :
125151 raise ValueError (f"Unsupported tensor type: { type (tensor )} " )
126152
@@ -129,15 +155,11 @@ def save_tensor_state_dict(
129155 f"{ tensor_name } :{ key } " : (
130156 value .detach ().clone () if isinstance (value , torch .Tensor ) else value
131157 )
132- for key , value in tensors_dict .items ()
158+ for key , value in tensor_dict .items ()
133159 }
134160
135- combined_metadata [tensor_name ] = metadata
136- combined_tensors_dict .update (prefixed_tensors_dict )
137-
138- combined_metadata ["tensor_names" ] = json .dumps (list (tensor_dict .keys ()))
161+ metadata [tensor_name ] = tensor_metadata
162+ tensors_data .update (prefixed_tensors_dict )
139163
140- save_file (combined_tensors_dict , file_path , metadata = combined_metadata )
141- logger .info (
142- f"Saved { len (tensor_dict )} tensor subclasses to { file_path } with metadata"
143- )
164+ metadata ["tensor_names" ] = json .dumps (list (tensors_dict .keys ()))
165+ return tensors_data , metadata
0 commit comments