
Commit 392a504

enable torchao safetensors support
1 parent 9643eaf commit 392a504

File tree: 3 files changed (+36, -44 lines)

src/transformers/modeling_utils.py

Lines changed: 18 additions & 21 deletions
@@ -166,6 +166,8 @@
 else:
     IS_SAGEMAKER_MP_POST_1_10 = False
 
+from torchao.prototype.safetensors.safetensors_utils import is_metadata_dict_torchao
+
 
 logger = logging.get_logger(__name__)
 

@@ -496,10 +498,9 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
 
 def load_state_dict(
     checkpoint_file: Union[str, os.PathLike],
-    is_quantized: bool = False, #change to hf_quantizer (default is none)
+    is_quantized: bool = False,
     map_location: Optional[Union[str, torch.device]] = "cpu",
     weights_only: bool = True,
-    hf_quantizer: Optional[HfQuantizer] = None,
 ):
     """
     Reads a `safetensor` or a `.bin` checkpoint file. We load the checkpoint on "cpu" by default.

@@ -596,7 +597,7 @@ def set_initialized_submodules(model, state_dict_keys):
     return not_initialized_submodules
 
 
-def _end_ptr(tensor: torch.Tensor) -> int:
+def _end_ptr(tensor: torch.Tensor) -> int:
     # extract the end of the pointer if the tensor is a slice of a bigger tensor
     if tensor.nelement():
         stop = tensor.view(-1)[-1].data_ptr() + tensor.element_size()

@@ -728,7 +729,7 @@ def _load_state_dict_into_meta_model(
     keep_in_fp32_regex: Optional[re.Pattern] = None,
     unexpected_keys: Optional[list[str]] = None,  # passing `unexpected` for cleanup from quantization items
     device_mesh: Optional["torch.distributed.device_mesh.DeviceMesh"] = None,
-    metadata: Optional[dict] = None
+    metadata: Optional[dict] = None,
 ) -> tuple[Optional[dict], Optional[dict]]:
     """Load parameters from `meta_state_dict` into the model. The parameters of the `meta_state_dict` are on the meta
     device in order to easily infer the shapes and dtypes that they will have. Then proper parameters are then loaded

@@ -746,14 +747,13 @@ def _load_state_dict_into_meta_model(
     is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
         QuantizationMethod.HQQ,
         QuantizationMethod.BITS_AND_BYTES,
-        QuantizationMethod.TORCHAO
+        QuantizationMethod.TORCHAO,
     }
     is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_ao
     file_pointer = None
     if is_meta_state_dict:
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
-
-    if hf_quantizer and hasattr(hf_quantizer, "transform_state_dict") and metadata:
+    if hf_quantizer and hasattr(hf_quantizer, "transform_state_dict") and is_metadata_dict_torchao(metadata):
         state_dict = hf_quantizer.transform_state_dict(state_dict, metadata)
 
     for param_name, empty_param in state_dict.items():
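The `is_metadata_dict_torchao(metadata)` gate above is what keeps non-torchao checkpoints on the existing fast path: only shards whose safetensors header carries torchao-style metadata go through `transform_state_dict`. A minimal sketch of that intent, outside the diff (the helper name `maybe_unflatten_torchao` is hypothetical, and in the real function the metadata arrives as the `metadata` argument rather than being re-read from the file):

from safetensors import safe_open
from torchao.prototype.safetensors.safetensors_utils import is_metadata_dict_torchao

def maybe_unflatten_torchao(shard_file, state_dict, hf_quantizer):
    # Read the user metadata stored in the safetensors header (assumption: this
    # is where torchao metadata lives when a shard was saved with it).
    with safe_open(shard_file, framework="pt", device="cpu") as f:
        metadata = f.metadata()
    # Only torchao-produced metadata triggers the reconstruction step.
    if hf_quantizer is not None and hasattr(hf_quantizer, "transform_state_dict") and is_metadata_dict_torchao(metadata):
        state_dict = hf_quantizer.transform_state_dict(state_dict, metadata)
    return state_dict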
@@ -787,8 +787,7 @@ def _load_state_dict_into_meta_model(
                     device_map=device_map,
                 )
             )
-        ):
-            # In this case, the param is already on the correct device!
+        ):  # In this case, the param is already on the correct device!
             shard_and_distribute_module(
                 model,
                 param,

@@ -938,7 +937,7 @@ def load_shard_file(args):
     # If shard_file is "", we use the existing state_dict instead of loading it
     if shard_file != "":
         state_dict = load_state_dict(
-            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only, hf_quantizer=hf_quantizer
+            shard_file, is_quantized=is_quantized, map_location=map_location, weights_only=weights_only
         )
 
     # Fix the key names

@@ -3987,11 +3986,11 @@ def save_pretrained(
             and hf_quantizer.is_serializable(safe_serialization=safe_serialization)
         )
 
-        # if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
-        #     raise ValueError(
-        #         f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
-        #         " the logger on the traceback to understand the reason why the quantized model is not serializable."
-        #     )
+        if hf_quantizer is not None and not _hf_peft_config_loaded and not quantization_serializable:
+            raise ValueError(
+                f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                " the logger on the traceback to understand the reason why the quantized model is not serializable."
+            )
 
         if "save_config" in kwargs:
             warnings.warn(

@@ -4020,9 +4019,9 @@ def save_pretrained(
             repo_id = self._create_repo(repo_id, **kwargs)
             files_timestamps = self._get_files_timestamps(save_directory)
 
+        metadata = {}
         if hf_quantizer is not None:
             state_dict = hf_quantizer.get_state_dict(self)
-            metadata = {}
             if isinstance(state_dict, tuple):
                 state_dict, metadata = state_dict
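Initializing `metadata = {}` before the quantizer branch means every save path now has a metadata dict to hand to the safetensors writer, and a quantizer whose `get_state_dict` returns a `(state_dict, metadata)` tuple, as the torchao quantizer now does, gets its metadata unpacked. A sketch of how such metadata would typically be forwarded when a shard is written (the exact call site in `save_pretrained` is outside this hunk; safetensors metadata must be string-to-string):

from safetensors.torch import save_file

def save_shard(state_dict, filename, extra_metadata=None):
    # transformers normally records {"format": "pt"}; quantizer metadata is merged in.
    metadata = {"format": "pt"}
    if extra_metadata:
        metadata.update({k: str(v) for k, v in extra_metadata.items()})
    save_file(state_dict, filename, metadata=metadata)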

@@ -4171,8 +4170,7 @@ def save_pretrained(
             else:
                 ptrs[id_tensor_storage(tensor)].append(name)
 
-        # shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
-        shared_ptrs = {}
+        shared_ptrs = {ptr: names for ptr, names in ptrs.items() if len(names) > 1}
 
         # Recursively descend to find tied weight keys
         _tied_weights_keys = _get_tied_weight_keys(self)

@@ -5095,6 +5093,7 @@ def from_pretrained(
         )
 
         from_pt = not (from_tf | from_flax)
+
         if from_pt:
             if gguf_file:
                 from .modeling_gguf_pytorch_utils import load_gguf_checkpoint

@@ -5113,7 +5112,6 @@ def from_pretrained(
             )
 
         config.name_or_path = pretrained_model_name_or_path
-
         model_init_context = cls.get_init_context(is_quantized, _is_ds_init_called)
         config = copy.deepcopy(config)  # We do not want to modify the config inplace in from_pretrained.
         with ContextManagers(model_init_context):

@@ -5448,7 +5446,7 @@ def _load_pretrained_model(
         is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
             QuantizationMethod.HQQ,
             QuantizationMethod.BITS_AND_BYTES,
-            QuantizationMethod.TORCHAO
+            QuantizationMethod.TORCHAO,
         }
 
         # Get all the keys of the state dicts that we have to initialize the model

@@ -5567,7 +5565,6 @@ def _load_pretrained_model(
         if sharded_metadata is None:
             weight_map = dict.fromkeys(checkpoint_keys, checkpoint_files[0])
         else:
-            # weight file full path
             folder = os.path.sep.join(checkpoint_files[0].split(os.path.sep)[:-1])
             # Fix the weight map keys according to the key mapping
             weight_map = {

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 9 additions & 8 deletions
@@ -35,9 +35,9 @@
     import torch
     import torch.nn as nn
 
-    from torchao.quantization import Float8Tensor
 
-    from torchao.prototype.safetensors.safetensors_support import save_tensor_state_dict, load_tensor_state_dict
+    from torchao.prototype.safetensors.safetensors_support import flatten_tensor_state_dict, unflatten_tensor_state_dict
+
 
 logger = logging.get_logger(__name__)
 

@@ -142,8 +142,7 @@ def update_dtype(self, dtype):
         return dtype
 
     def get_state_dict(self, model):
-        return save_tensor_state_dict(model.state_dict())
-
+        return flatten_tensor_state_dict(model.state_dict())
 
     def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):

@@ -228,7 +227,6 @@ def check_quantized_param(
             _QUANTIZABLE.append(torch.nn.Embedding)
         return isinstance(module, tuple(_QUANTIZABLE)) and (tensor_name == "weight")
 
-
     def create_quantized_param(
         self,
         model: "PreTrainedModel",

@@ -289,7 +287,7 @@ def create_quantized_param(
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
     def transform_state_dict(self, tensor_data, metadata):
-        return load_tensor_state_dict(tensor_data=tensor_data, provided_metadata=metadata)
+        return unflatten_tensor_state_dict(tensor_data, metadata)
 
     def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
@@ -309,10 +307,13 @@ def _process_model_after_weight_loading(self, model, **kwargs):
 
     def is_serializable(self, safe_serialization=None) -> bool:
         if safe_serialization:
+            from torchao.quantization import Float8WeightOnlyConfig
+
             logger.warning(
-                "torchao quantized model does not support safe serialization, please set `safe_serialization` to False"
+                "torchao quantized model only supports safe serialization for Float8WeightOnlyConfig, please set `safe_serialization` to False if you are using a different config"
             )
-            return False
+
+            return isinstance(self.quantization_config.quant_type, Float8WeightOnlyConfig)
         _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
             "0.25.0"
         )
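With this change, `is_serializable(safe_serialization=True)` returns True only when the config's `quant_type` is a `Float8WeightOnlyConfig`; any other torchao config still returns False (with the updated warning), and the re-enabled ValueError in `save_pretrained` then fails fast instead of silently writing a checkpoint that cannot be reloaded. Roughly (a sketch of the gate, not the exact quantizer code):

from torchao.quantization import Float8WeightOnlyConfig

def torchao_safe_serializable(quantization_config):
    # Sketch of the new gate: with safe_serialization=True, only
    # Float8WeightOnlyConfig-quantized models report themselves as serializable;
    # the non-safetensors path keeps its existing huggingface_hub version check.
    return isinstance(quantization_config.quant_type, Float8WeightOnlyConfig)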

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 9 additions & 15 deletions
@@ -399,7 +399,7 @@ def test_autoquant(self):
 
         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
 
-        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )

@@ -412,26 +412,21 @@ class TorchAoSerializationTest(unittest.TestCase):
     input_text = "What are we having for dinner?"
     max_new_tokens = 10
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
-    quant_scheme = "int4_weight_only"
-    quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
-        if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
-        else {"group_size": 32}
-    )
     device = "cpu"
 
     # called only once for all test in this class
     @classmethod
     def setUpClass(cls):
         cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
-        cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
+        cls.EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
 
     def setUp(self):
-        self.quant_config = TorchAoConfig(self.quant_scheme, **self.quant_scheme_kwargs)
-        dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
+        from torchao.quantization import Float8WeightOnlyConfig
+
+        self.quant_config = TorchAoConfig(Float8WeightOnlyConfig())
         self.quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
-            dtype=dtype,
+            dtype=torch.bfloat16,
             device_map=self.device,
             quantization_config=self.quant_config,
         )

@@ -451,12 +446,11 @@ def check_serialization_expected_output(self, device, expected_output):
         """
         Test if we can serialize and load/infer the model again on the same device
         """
-        dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
+        dtype = torch.bfloat16
        with tempfile.TemporaryDirectory() as tmpdirname:
-            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
+            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=True)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device)
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
-
             output = loaded_quantized_model.generate(**input_ids, max_new_tokens=self.max_new_tokens)
             self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), expected_output)

@@ -511,7 +505,7 @@ def setUpClass(cls):
         EXPECTED_OUTPUTS = Expectations(
             {
                 ("xpu", 3): "What are we having for dinner?\n\nJessica: (smiling)",
-                ("cuda", 7): "What are we having for dinner?\n- 1. What is the temperature outside",
+                ("cuda", 7): "What are we having for dinner?\n\nJessica: (smiling)",
             }
         )
         # fmt: on
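Taken together, the updated test exercises the feature end to end: quantize with `Float8WeightOnlyConfig`, save with `safe_serialization=True`, reload, and generate. A standalone sketch of the same flow (model name and generation settings taken from the test above):

import tempfile

import torch
from torchao.quantization import Float8WeightOnlyConfig
from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
tokenizer = AutoTokenizer.from_pretrained(model_name)
quant_config = TorchAoConfig(Float8WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    model_name, dtype=torch.bfloat16, device_map="cpu", quantization_config=quant_config
)

with tempfile.TemporaryDirectory() as tmpdir:
    # Before this commit, torchao models had to be saved with safe_serialization=False.
    model.save_pretrained(tmpdir, safe_serialization=True)
    reloaded = AutoModelForCausalLM.from_pretrained(tmpdir, dtype=torch.bfloat16, device_map="cpu")
    inputs = tokenizer("What are we having for dinner?", return_tensors="pt")
    output = reloaded.generate(**inputs, max_new_tokens=10)
    print(tokenizer.decode(output[0], skip_special_tokens=True))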
