Commit 9643eaf

enable torchao safetensors
1 parent 34595cf commit 9643eaf

File tree: 2 files changed (+50, -18 lines)

src/transformers/modeling_utils.py

Lines changed: 38 additions & 18 deletions
@@ -496,9 +496,10 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
 
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
-    is_quantized: bool = False,
+    is_quantized: bool = False,  # change to hf_quantizer (default is none)
     map_location: Optional[Union[str, torch.device]] = "cpu",
     weights_only: bool = True,
+    hf_quantizer: Optional[HfQuantizer] = None,
 ):
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.
@@ -595,7 +596,7 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules
 
 
-def _end_ptr(tensor: torch.Tensor) -> int:
+def _end_ptr(tensor: torch.Tensor) -> int:
     # extract the end of the pointer if the tensor is a slice of a bigger tensor
     if tensor.nelement():
         stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()
@@ -727,6 +728,7 @@ def _load_state_dict_into_meta_model(
     keep_in_fp32_regex: Optional[re.Pattern] = None,
     unexpected_keys: Optional[list[str]] = None,  # passing `unexpected` for cleanup from quantization items
     device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
+    metadata: Optional[dict] = None
 ) -> tuple[Optional[dict], Optional[dict]]:
     """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
     device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded
@@ -741,15 +743,19 @@ def _load_state_dict_into_meta_model(
     device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
 
     is_quantized = hf_quantizer is not None
-    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+    is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
         QuantizationMethod.HQQ,
         QuantizationMethod.BITS_AND_BYTES,
+        QuantizationMethod.TORCHAO
     }
-    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
+    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_ao
     file_pointer = None
     if is_meta_state_dict:
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
 
+    if hf_quantizer and hasattr(hf_quantizer, "transform_state_dict") and metadata:
+        state_dict = hf_quantizer.transform_state_dict(state_dict, metadata)
+
     for param_name, empty_param in state_dict.items():
         if param_name not in expected_keys:  # when loading from ckpt, we skip param if doesnt exist in modeling
             continue
@@ -781,7 +787,8 @@ def _load_state_dict_into_meta_model(
                     device_map=device_map,
                 )
             )
-        ):  # In this case, the param is already on the correct device!
+        ):
+            # In this case, the param is already on the correct device!
             shard_and_distribute_module(
                 model,
                 param,
@@ -887,7 +894,7 @@ def load_shard_file(args):
         shard_file,
         state_dict,
         disk_only_shard_files,
-        is_hqq_or_bnb,
+        is_hqq_or_bnb_or_ao,
         is_quantized,
         device_map,
         hf_quantizer,
@@ -913,7 +920,7 @@ def load_shard_file(args):
     map_location = "cpu"
     if (
         shard_file.endswith(".safetensors")
-        and not is_hqq_or_bnb
+        and not is_hqq_or_bnb_or_ao
         and not (is_deepspeed_zero3_enabled() and not is_quantized)
     ):
         map_location = "meta"
@@ -931,11 +938,15 @@ def load_shard_file(args):
     # If shard_file is "", we use the existing state_dict instead of loading it
     if shard_file != "":
        state_dict = load_state_dict(
-            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
+            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only, hf_quantizer=hf_quantizer
        )
 
     # Fix the key names
     state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
+    metadata = None
+    if shard_file.endswith(".safetensors") and is_safetensors_available():
+        with safe_open(shard_file, framework="pt") as f:
+            metadata = f.metadata()
 
     error_msgs = []
 
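The loading side of this change hinges on the safetensors header metadata read in the hunk above. A minimal standalone sketch of that read (the shard file name is a placeholder):

# Sketch: safetensors keeps an optional str -> str metadata dict in the file
# header; torchao's serialized layout information travels there.
from safetensors import safe_open

with safe_open("model-00001-of-00002.safetensors", framework="pt") as f:  # hypothetical shard name
    metadata = f.metadata()  # dict[str, str] or None
print(metadata)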
@@ -959,6 +970,7 @@ def load_shard_file(args):
         keep_in_fp32_regex=keep_in_fp32_regex,
         unexpected_keys=unexpected_keys,
         device_mesh=device_mesh,
+        metadata=metadata,
     )
 
     return error_msgs, disk_offload_index, cpu_offload_index
@@ -3975,11 +3987,11 @@ def save_pretrained(
             and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
         )
 
-        if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
-            raise ValueError(
-                f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
-                " the logger on the traceback to understand the reason why the quantized model is not serializable."
-            )
+        # if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
+        #     raise ValueError(
+        #         f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+        #         " the logger on the traceback to understand the reason why the quantized model is not serializable."
+        #     )
 
         if "save_config" in kwargs:
             warnings.warn(
@@ -4010,6 +4022,10 @@ def save_pretrained(
 
         if hf_quantizer is not None:
             state_dict = hf_quantizer.get_state_dict(self)
+            metadata = {}
+            if isinstance(state_dict, tuple):
+                state_dict, metadata = state_dict
+
         # Only save the model itself if we are using distributed training
         model_to_save = unwrap_model(self)
         # save the string version of dtype to the config, e.g. convert torch.float32 => "float32"
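This hunk quietly widens the `get_state_dict` contract: a quantizer may now return either a plain state dict or a `(state_dict, metadata)` pair. A defensive-unpacking sketch of that convention (`hf_quantizer` and `model` are assumed to exist in scope):

# Sketch of the convention save_pretrained now assumes: get_state_dict may
# return a bare dict or a (state_dict, metadata) tuple.
result = hf_quantizer.get_state_dict(model)
if isinstance(result, tuple):
    state_dict, metadata = result
else:
    state_dict, metadata = result, {}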
@@ -4155,7 +4171,8 @@ def save_pretrained(
                 else:
                     ptrs[id_tensor_storage(tensor)].append(name)
 
-            shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+            # shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
+            shared_ptrs = {}
 
             # Recursively descend to find tied weight keys
             _tied_weights_keys = _get_tied_weight_keys(self)
@@ -4286,7 +4303,8 @@ def save_pretrained(
             if safe_serialization:
                 # At some point we will need to deal better with save_function (used for TPU and other distributed
                 # joyfulness), but for now this enough.
-                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata={"format": "pt"})
+                metadata["format"] = "pt"
+                safe_save_file(shard, os.path.join(save_directory, shard_file), metadata=metadata)
             else:
                 save_function(shard, os.path.join(save_directory, shard_file))
 
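`safe_save_file` here is `safetensors.torch.save_file`, whose metadata argument only accepts string keys and string values, so whatever the quantizer returns must already be serialized (e.g. to JSON strings) before it is merged with `{"format": "pt"}`. A self-contained sketch with a made-up metadata key:

import torch
from safetensors.torch import save_file

tensors = {"linear.weight": torch.zeros(4, 4)}
# Values must be strings; "tensor_layout" is a hypothetical key standing in for
# whatever the torchao quantizer writes into the header.
metadata = {"format": "pt", "tensor_layout": "{}"}
save_file(tensors, "model.safetensors", metadata=metadata)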
@@ -5077,7 +5095,6 @@ def from_pretrained(
         )
 
         from_pt = not (from_tf | from_flax)
-
         if from_pt:
             if gguf_file:
                 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint
@@ -5096,6 +5113,7 @@ def from_pretrained(
         )
 
         config.name_or_path = pretrained_model_name_or_path
+
         model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
         with ContextManagers(model_init_context):
@@ -5427,9 +5445,10 @@ def _load_pretrained_model(
         QuantizationMethod.HQQ,
         QuantizationMethod.QUARK,
     }
-    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+    is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
         QuantizationMethod.HQQ,
         QuantizationMethod.BITS_AND_BYTES,
+        QuantizationMethod.TORCHAO
     }
 
     # Get all the keys of the state dicts that we have to initialize the model
@@ -5548,6 +5567,7 @@ def _load_pretrained_model(
     if sharded_metadata is None:
         weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
     else:
+        # weight file full path
         folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
         # Fix the weight map keys according to the key mapping
         weight_map = {
@@ -5602,7 +5622,7 @@ def _load_pretrained_model(
             shard_file,
             state_dict,
             disk_only_shard_files,
-            is_hqq_or_bnb,
+            is_hqq_or_bnb_or_ao,
             is_quantized,
             device_map,
             hf_quantizer,

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 12 additions & 0 deletions
@@ -35,6 +35,10 @@
     import torch
     import torch.nn as nn
 
+    from torchao.quantization import Float8Tensor
+
+    from torchao.prototype.safetensors.safetensors_support import save_tensor_state_dict, load_tensor_state_dict
+
 logger = logging.get_logger(__name__)
 
 
@@ -137,6 +141,10 @@ def update_dtype(self, dtype):
             dtype = torch.float32
         return dtype
 
+    def get_state_dict(self, model):
+        return save_tensor_state_dict(model.state_dict())
+
+
     def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
             from accelerate.utils import CustomDtype
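Based only on how these helpers are called in this commit, the torchao prototype pair appears to round-trip a quantized state dict through plain tensors plus string metadata. A hedged sketch; the tuple return shape is an assumption inferred from the `save_pretrained` unpacking above:

from torchao.prototype.safetensors.safetensors_support import (
    load_tensor_state_dict,
    save_tensor_state_dict,
)

# `model` is assumed to be a torchao-quantized nn.Module (e.g. Float8Tensor weights).
tensor_data, metadata = save_tensor_state_dict(model.state_dict())
restored = load_tensor_state_dict(tensor_data=tensor_data, provided_metadata=metadata)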
@@ -220,6 +228,7 @@ def check_quantized_param(
             _QUANTIZABLE.append(torch.nn.Embedding)
         return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
 
+
     def create_quantized_param(
         self,
         model: "PreTrainedModel",
@@ -279,6 +288,9 @@ def create_quantized_param(
 
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def transform_state_dict(self, tensor_data, metadata):
+        return load_tensor_state_dict(tensor_data=tensor_data, provided_metadata=metadata)
+
     def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
         if self.quantization_config.quant_type == "autoquant":
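Taken together, the two files enable a safetensors round trip for torchao-quantized checkpoints. An end-to-end sketch of the intended flow; the model name and quantization config are placeholders, and the exact torchao config class depends on the installed torchao version:

from torchao.quantization import Float8WeightOnlyConfig  # assumed torchao config class
from transformers import AutoModelForCausalLM, TorchAoConfig

quant_config = TorchAoConfig(quant_type=Float8WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-125m",  # placeholder model
    quantization_config=quant_config,
    torch_dtype="bfloat16",
)

# save_pretrained now calls TorchAoHfQuantizer.get_state_dict(), which returns
# (state_dict, metadata); the metadata lands in the safetensors header.
model.save_pretrained("opt-125m-torchao-fp8", safe_serialization=True)

# On reload, load_shard_file reads that header metadata and hands it to
# transform_state_dict(), which rebuilds the quantized tensor subclasses.
reloaded = AutoModelForCausalLM.from_pretrained("opt-125m-torchao-fp8")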
