Commit 6c79b56

enable torchao safetensors support
1 parent 75976a6 commit 6c79b56

4 files changed, +121 -52 lines changed

src/transformers/modeling_utils.py

Lines changed: 18 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -496,10 +496,9 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
496496

497497
def load_state_dict(
498498
checkpoint_file: Union[str, os.PathLike],
499-
is_quantized: bool = False, #change to hf_quantizer (default is none)
499+
is_quantized: bool = False,
500500
map_location: Optional[Union[str, torch.device]] = "cpu",
501501
weights_only: bool = True,
502-
hf_quantizer: Optional[HfQuantizer] = None,
503502
):
504503
"""
505504
Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
@@ -596,7 +595,7 @@ def set_initialized_submodules(model, state_dict_keys):
596595
return not_initialized_submodules
597596

598597

599-
def _end_ptr(tensor: torch.Tensor) -> int:
598+
def _end_ptr(tensor: torch.Tensor) -> int:
600599
# extract the end of the pointer if the tensor is a slice of a bigger tensor
601600
if tensor.nelement():
602601
stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
@@ -728,7 +727,6 @@ def _load_state_dict_into_meta_model(
728727
keep_in_fp32_regex: Optional[re.Pattern] = None,
729728
unexpected_keys: Optional[list[str]] = None, # passing `unexpected` for cleanup from quantization items
730729
device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
731-
metadata: Optional[dict] = None
732730
) -> tuple[Optional[dict], Optional[dict]]:
733731
"""Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
734732
device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@@ -746,16 +744,13 @@ def _load_state_dict_into_meta_model(
746744
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
747745
QuantizationMethod.HQQ,
748746
QuantizationMethod.BITS_AND_BYTES,
749-
QuantizationMethod.TORCHAO
747+
QuantizationMethod.TORCHAO,
750748
}
751749
is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_ao
752750
file_pointer = None
753751
if is_meta_state_dict:
754752
file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
755753

756-
if hf_quantizer and hasattr(hf_quantizer, "transform_state_dict") and metadata:
757-
state_dict = hf_quantizer.transform_state_dict(state_dict, metadata)
758-
759754
for param_name, empty_param in state_dict.items():
760755
if param_name not in expected_keys: # when loading from ckpt, we skip param if doesnt exist in modeling
761756
continue
@@ -787,8 +782,7 @@ def _load_state_dict_into_meta_model(
787782
device_map=device_map,
788783
)
789784
)
790-
):
791-
# In this case, the param is already on the correct device!
785+
): # In this case, the param is already on the correct device!
792786
shard_and_distribute_module(
793787
model,
794788
param,
@@ -938,7 +932,7 @@ def load_shard_file(args):
938932
# If shard_file is "", we use the existing state_dict instead of loading it
939933
if shard_file != "":
940934
state_dict = load_state_dict(
941-
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only, hf_quantizer=hf_quantizer
935+
shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
942936
)
943937

944938
# Fix the key names
@@ -948,6 +942,9 @@ def load_shard_file(args):
948942
with safe_open(shard_file, framework="pt") as f:
949943
metadata = f.metadata()
950944

945+
if hf_quantizer:
946+
state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
947+
951948
error_msgs = []
952949

953950
if is_deepspeed_zero3_enabled() and not is_quantized:
@@ -970,7 +967,6 @@ def load_shard_file(args):
970967
keep_in_fp32_regex=keep_in_fp32_regex,
971968
unexpected_keys=unexpected_keys,
972969
device_mesh=device_mesh,
973-
metadata=metadata,
974970
)
975971

976972
return error_msgs, disk_offload_index, cpu_offload_index
@@ -3994,11 +3990,11 @@ def save_pretrained(
39943990
and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
39953991
)
39963992

3997-
# if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
3998-
# raise ValueError(
3999-
# f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
4000-
# " the logger on the traceback to understand the reason why the quantized model is not serializable."
4001-
# )
3993+
if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
3994+
raise ValueError(
3995+
f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
3996+
" the logger on the traceback to understand the reason why the quantized model is not serializable."
3997+
)
40023998

40033999
if "save_config" in kwargs:
40044000
warnings.warn(
@@ -4029,10 +4025,8 @@ def save_pretrained(
40294025

40304026
metadata = {}
40314027
if hf_quantizer is not None:
4032-
state_dict = hf_quantizer.get_state_dict(self)
4033-
metadata = {}
4034-
if isinstance(state_dict, tuple):
4035-
state_dict, metadata = state_dict
4028+
state_dict, metadata = hf_quantizer.get_state_dict_and_metadata(self, safe_serialization)
4029+
metadata["format"] = "pt"
40364030

40374031
# Only save the model itself if we are using distributed training
40384032
model_to_save = unwrap_model(self)
@@ -4180,8 +4174,7 @@ def save_pretrained(
41804174
else:
41814175
ptrs[id_tensor_storage(tensor)].append(name)
41824176

4183-
# shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
4184-
shared_ptrs = {}
4177+
shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
41854178

41864179
# Recursively descend to find tied weight keys
41874180
_tied_weights_keys = _get_tied_weight_keys(self)
@@ -4312,7 +4305,6 @@ def save_pretrained(
43124305
if safe_serialization:
43134306
# At some point we will need to deal better with save_function (used for TPU and other distributed
43144307
# joyfulness), but for now this enough.
4315-
metadata["format"] = "pt"
43164308
safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
43174309
else:
43184310
save_function(shard, os.path.join(save_directory, shard_file))
@@ -4808,7 +4800,6 @@ def from_pretrained(
48084800

48094801
if distributed_config is not None:
48104802
tp_plan = "auto"
4811-
48124803
# Not used anymore -- remove them from the kwargs
48134804
_ = kwargs.pop("resume_download", None)
48144805
_ = kwargs.pop("mirror", None)
@@ -4960,7 +4951,6 @@ def from_pretrained(
49604951
"Using a `device_map`, `tp_plan`, `torch.device` context manager or setting `torch.set_default_device(device)` "
49614952
"requires `accelerate`. You can install it with `pip install accelerate`"
49624953
)
4963-
49644954
# handling bnb config from kwargs, remove after `load_in_{4/8}bit` deprecation.
49654955
if load_in_4bit or load_in_8bit:
49664956
if quantization_config is not None:
@@ -5030,7 +5020,6 @@ def from_pretrained(
50305020
"(*.safetensors) nor a safetensors index file (*.safetensors.index.json): "
50315021
f"{transformers_explicit_filename}"
50325022
)
5033-
50345023
hf_quantizer, config, dtype, device_map = get_hf_quantizer(
50355024
config, quantization_config, dtype, from_tf, from_flax, device_map, weights_only, user_agent
50365025
)
@@ -5103,6 +5092,7 @@ def from_pretrained(
51035092
)
51045093

51055094
from_pt = not (from_tf | from_flax)
5095+
51065096
if from_pt:
51075097
if gguf_file:
51085098
from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -5121,7 +5111,6 @@ def from_pretrained(
51215111
)
51225112

51235113
config.name_or_path = pretrained_model_name_or_path
5124-
51255114
model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
51265115
config = copy.deepcopy(config) # We do not want to modify the config inplace in from_pretrained.
51275116
with ContextManagers(model_init_context):
@@ -5449,7 +5438,7 @@ def _load_pretrained_model(
54495438
is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
54505439
QuantizationMethod.HQQ,
54515440
QuantizationMethod.BITS_AND_BYTES,
5452-
QuantizationMethod.TORCHAO
5441+
QuantizationMethod.TORCHAO,
54535442
}
54545443

54555444
# Get all the keys of the state dicts that we have to initialize the model
@@ -5568,7 +5557,6 @@ def _load_pretrained_model(
55685557
if sharded_metadata is None:
55695558
weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
55705559
else:
5571-
# weight file full path
55725560
folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
55735561
# Fix the weight map keys according to the key mapping
55745562
weight_map = {
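The save and load paths above communicate only through the safetensors per-file metadata: save_pretrained stamps metadata["format"] = "pt" plus whatever the quantizer returns, and load_shard_file reads the metadata back and hands it to update_state_dict_with_metadata. A standalone sketch of that round-trip using the plain safetensors API (the tensor name and the extra metadata key are made up for illustration; safetensors metadata is a flat str-to-str dict):

import torch
from safetensors import safe_open
from safetensors.torch import save_file

state_dict = {"linear.weight": torch.randn(4, 4)}
metadata = {"format": "pt", "quantizer_hint": "flattened"}  # "quantizer_hint" is a hypothetical key

# Save tensors with metadata, the same channel save_pretrained uses via safe_save_file
save_file(state_dict, "shard.safetensors", metadata=metadata)

with safe_open("shard.safetensors", framework="pt") as f:
    loaded_metadata = f.metadata()  # what load_shard_file passes to the quantizer
    loaded = {key: f.get_tensor(key) for key in f.keys()}

assert loaded_metadata["format"] == "pt"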

src/transformers/quantizers/base.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -342,6 +342,9 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
342342
"""Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
343343
return None, {}
344344

345+
def update_state_dict_with_metadata(self, state_dict, metadata):
346+
return state_dict
347+
345348
@abstractmethod
346349
def _process_model_before_weight_loading(self, model, **kwargs): ...
347350
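The new default makes the hook a no-op, so only quantizers that rewrite their state dict at save time need to override it. The two methods form a symmetric extension point: get_state_dict_and_metadata may transform the state dict when saving and return metadata describing the transformation, and update_state_dict_with_metadata undoes it when loading. A hypothetical subclass (illustration only, not part of this commit) might pair them like this:

class MyQuantizer(HfQuantizer):
    # Hypothetical quantizer sketch: flatten at save time, restore at load time.
    MARKER = {"my_quantizer": "v1"}  # metadata values must be strings

    def get_state_dict_and_metadata(self, model, safe_serialization=False):
        state_dict = dict(model.state_dict())
        # ... replace tensor subclasses with plain tensors here ...
        return state_dict, dict(self.MARKER)

    def update_state_dict_with_metadata(self, state_dict, metadata):
        if metadata and metadata.get("my_quantizer") == "v1":
            # ... rebuild the tensor subclasses from the plain tensors ...
            pass
        return state_dict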

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,16 @@
3535
import torch
3636
import torch.nn as nn
3737

38-
from torchao.quantization import Float8Tensor
38+
if is_torchao_available():
39+
import torchao
40+
41+
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
42+
from torchao.prototype.safetensors.safetensors_support import (
43+
flatten_tensor_state_dict,
44+
unflatten_tensor_state_dict,
45+
)
46+
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
3947

40-
from torchao.prototype.safetensors.safetensors_support import save_tensor_state_dict, load_tensor_state_dict
4148

4249
logger = logging.get_logger(__name__)
4350

@@ -85,6 +92,13 @@ def _linear_extra_repr(self):
8592
return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
8693

8794

95+
if is_torchao_available():
96+
SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
97+
torchao.quantization.Float8WeightOnlyConfig,
98+
torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
99+
]
100+
101+
88102
class TorchAoHfQuantizer(HfQuantizer):
89103
"""
90104
Quantizer for torchao: https://github.com/pytorch/ao/
@@ -141,9 +155,19 @@ def update_dtype(self, dtype):
141155
dtype = torch.float32
142156
return dtype
143157

144-
def get_state_dict(self, model):
145-
return save_tensor_state_dict(model.state_dict())
146-
158+
def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool] = False):
159+
'''
160+
If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
161+
the safetensors format.
162+
'''
163+
if (
164+
type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS
165+
and safe_serialization
166+
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0")
167+
):
168+
return flatten_tensor_state_dict(model.state_dict())
169+
else:
170+
return super().get_state_dict_and_metadata(model)
147171

148172
def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
149173
if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
@@ -228,7 +252,6 @@ def check_quantized_param(
228252
_QUANTIZABLE.append(torch.nn.Embedding)
229253
return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
230254

231-
232255
def create_quantized_param(
233256
self,
234257
model: "PreTrainedModel",
@@ -288,8 +311,17 @@ def create_quantized_param(
288311

289312
quantize_(module, self.quantization_config.get_apply_tensor_subclass())
290313

291-
def transform_state_dict(self, tensor_data, metadata):
292-
return load_tensor_state_dict(tensor_data=tensor_data, provided_metadata=metadata)
314+
def update_state_dict_with_metadata(self, state_dict, metadata):
315+
'''
316+
If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
317+
from the provided state_dict and metadata.
318+
'''
319+
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0") and is_metadata_torchao(
320+
metadata
321+
):
322+
return unflatten_tensor_state_dict(state_dict, metadata)
323+
else:
324+
return state_dict
293325

294326
def _process_model_after_weight_loading(self, model, **kwargs):
295327
"""No process required for torchao quantized model"""
@@ -309,10 +341,15 @@ def _process_model_after_weight_loading(self, model, **kwargs):
309341

310342
def is_serializable(self, safe_serialization=None) -> bool:
311343
if safe_serialization:
312-
logger.warning(
313-
"torchao quantized model does not support safe serialization, please set `safe_serialization` to False"
344+
_is_torchao_serializable = (
345+
type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS
314346
)
315-
return False
347+
if not _is_torchao_serializable:
348+
logger.warning(
349+
f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, please set `safe_serialization` to False if you are using a different config"
350+
)
351+
return _is_torchao_serializable
352+
316353
_is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
317354
"0.25.0"
318355
)
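End to end, this commit makes save_pretrained(..., safe_serialization=True) work for the two supported torchao float8 configs: the quantizer flattens the tensor subclasses into plain tensors plus metadata on save, and unflattens them on load. A usage sketch (the model id and output path are placeholders; assumes torchao >= 0.14.0 is installed):

import torch
from transformers import AutoModelForCausalLM, TorchAoConfig
from torchao.quantization import Float8WeightOnlyConfig

quant_config = TorchAoConfig(quant_type=Float8WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-3.1-8B",  # placeholder model id
    dtype=torch.bfloat16,
    device_map="auto",
    quantization_config=quant_config,
)

# Previously is_serializable() always returned False for safe serialization of
# torchao models; with this commit the float8 configs above save to safetensors.
model.save_pretrained("./model-fp8", safe_serialization=True)
reloaded = AutoModelForCausalLM.from_pretrained("./model-fp8", device_map="auto")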
