
Commit 1d81247

[torchao safetensors] integrate torchao safetensors support with transformers (#40735)

* enable torchao safetensors
* enable torchao safetensors support
* add more version checking

1 parent: b533cec
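
In user-facing terms, this commit lets a torchao-quantized model round-trip through safetensors. Below is a minimal sketch of the flow it enables, assuming torchao >= 0.14.0, a CUDA device, and the TinyLlama checkpoint used in the tests further down (the output directory name is arbitrary):

import torch
from torchao.quantization import Float8WeightOnlyConfig
from transformers import AutoModelForCausalLM, TorchAoConfig

model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# Quantize on load with one of the two configs that support safe serialization.
quant_config = TorchAoConfig(Float8WeightOnlyConfig())
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    dtype=torch.bfloat16,
    device_map="cuda",
    quantization_config=quant_config,
)

# Previously this required safe_serialization=False; it now writes
# .safetensors shards plus torchao metadata in the file header.
model.save_pretrained("tinyllama-float8", safe_serialization=True)

# Reloading goes through the new metadata-aware shard-loading path.
reloaded = AutoModelForCausalLM.from_pretrained(
    "tinyllama-float8", dtype=torch.bfloat16, device_map="cuda"
)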

File tree

4 files changed: +129 −18 lines changed


src/transformers/modeling_utils.py

Lines changed: 15 additions & 6 deletions
@@ -727,11 +727,12 @@ def _load_state_dict_into_meta_model(
     device_map_regex = "|".join([re.escape(k) for k in sorted(device_map.keys(), reverse=True)])
 
     is_quantized = hf_quantizer is not None
-    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+    is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
         QuantizationMethod.HQQ,
         QuantizationMethod.BITS_AND_BYTES,
+        QuantizationMethod.TORCHAO,
     }
-    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb
+    is_meta_state_dict = shard_file.endswith(".safetensors") and not is_hqq_or_bnb_or_ao
     file_pointer = None
     if is_meta_state_dict:
         file_pointer = safe_open(shard_file, framework="pt", device=tensor_device)
@@ -873,7 +874,7 @@ def load_shard_file(args):
         shard_file,
         state_dict,
         disk_only_shard_files,
-        is_hqq_or_bnb,
+        is_hqq_or_bnb_or_ao,
         is_quantized,
         device_map,
         hf_quantizer,
@@ -899,7 +900,7 @@ def load_shard_file(args):
     map_location = "cpu"
     if (
         shard_file.endswith(".safetensors")
-        and not is_hqq_or_bnb
+        and not is_hqq_or_bnb_or_ao
         and not (is_deepspeed_zero3_enabled() and not is_quantized)
     ):
         map_location = "meta"
@@ -922,6 +923,13 @@ def load_shard_file(args):
 
     # Fix the key names
     state_dict = {key_renaming_mapping[k]: v for k, v in state_dict.items() if k in key_renaming_mapping}
+    metadata = None
+    if shard_file.endswith(".safetensors") and is_safetensors_available():
+        with safe_open(shard_file, framework="pt") as f:
+            metadata = f.metadata()
+
+    if hf_quantizer:
+        state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
 
     error_msgs = []
 
@@ -5277,9 +5285,10 @@ def _load_pretrained_model(
         QuantizationMethod.HQQ,
         QuantizationMethod.QUARK,
     }
-    is_hqq_or_bnb = is_quantized and hf_quantizer.quantization_config.quant_method in {
+    is_hqq_or_bnb_or_ao = is_quantized and hf_quantizer.quantization_config.quant_method in {
         QuantizationMethod.HQQ,
         QuantizationMethod.BITS_AND_BYTES,
+        QuantizationMethod.TORCHAO,
     }
 
     # Get all the keys of the state dicts that we have to initialize the model
@@ -5451,7 +5460,7 @@ def _load_pretrained_model(
         shard_file,
         state_dict,
         disk_only_shard_files,
-        is_hqq_or_bnb,
+        is_hqq_or_bnb_or_ao,
         is_quantized,
         device_map,
         hf_quantizer,
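
The loading-side change reduces to: when a shard is a safetensors file, read the header metadata and let the quantizer rebuild its tensors from it. A self-contained sketch of that flow (load_shard_state_dict is a hypothetical wrapper; the body mirrors the hunk at new line 923):

from safetensors import safe_open

def load_shard_state_dict(shard_file, state_dict, hf_quantizer=None):
    # safetensors files carry free-form string metadata in their header;
    # torchao >= 0.14.0 uses it to describe flattened tensor subclasses.
    metadata = None
    if shard_file.endswith(".safetensors"):
        with safe_open(shard_file, framework="pt") as f:
            metadata = f.metadata()
    if hf_quantizer is not None:
        # The base hook is a no-op; TorchAoHfQuantizer unflattens here.
        state_dict = hf_quantizer.update_state_dict_with_metadata(state_dict, metadata)
    return state_dict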

src/transformers/quantizers/base.py

Lines changed: 4 additions & 0 deletions
@@ -342,6 +342,10 @@ def get_state_dict_and_metadata(self, model, safe_serialization=False):
         """Get state dict and metadata. Useful when we need to modify a bit the state dict due to quantization"""
         return None, {}
 
+    def update_state_dict_with_metadata(self, state_dict, metadata):
+        """Update state dict with metadata. Default behaviour returns state_dict"""
+        return state_dict
+
     @abstractmethod
     def _process_model_before_weight_loading(self, model, **kwargs): ...
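
The base class ships a no-op so existing quantizers are untouched; a backend overrides the hook only when its on-disk format needs the header metadata. A hypothetical subclass for illustration (rebuild_my_tensors and the "my-backend" key are made up):

from transformers.quantizers.base import HfQuantizer

class MyQuantizer(HfQuantizer):
    # ... abstract methods elided for brevity ...

    def update_state_dict_with_metadata(self, state_dict, metadata):
        # Only act on metadata written by this backend.
        if metadata and metadata.get("format") == "my-backend":
            state_dict = rebuild_my_tensors(state_dict, metadata)  # hypothetical helper
        return state_dict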

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 56 additions & 4 deletions
@@ -35,6 +35,17 @@
     import torch
     import torch.nn as nn
 
+if is_torchao_available():
+    import torchao
+
+    if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0"):
+        from torchao.prototype.safetensors.safetensors_support import (
+            flatten_tensor_state_dict,
+            unflatten_tensor_state_dict,
+        )
+        from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao
+
+
 logger = logging.get_logger(__name__)
 
 
@@ -81,6 +92,15 @@ def _linear_extra_repr(self):
     return f"in_features={self.weight.shape[1]}, out_features={self.weight.shape[0]}, weight={weight}"
 
 
+if is_torchao_available():
+    SUPPORTED_SAFE_SERIALIZATION_CONFIGS = [
+        torchao.quantization.Float8WeightOnlyConfig,
+        torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
+    ]
+
+    TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
+
+
 class TorchAoHfQuantizer(HfQuantizer):
     """
     Quantizer for torchao: https://github.com/pytorch/ao/
@@ -137,6 +157,21 @@ def update_dtype(self, dtype):
             dtype = torch.float32
         return dtype
 
+    def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool] = False):
+        """
+        If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
+        the safetensors format.
+        """
+        if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
+            if TORCHAO_VERSION >= version.parse("0.14.0"):
+                return flatten_tensor_state_dict(model.state_dict())
+            else:
+                raise RuntimeError(
+                    f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
+                )
+        else:
+            return super().get_state_dict_and_metadata(model)
+
     def adjust_target_dtype(self, dtype: "torch.dtype") -> "torch.dtype":
         if version.parse(importlib.metadata.version("accelerate")) > version.parse("0.19.0"):
             from accelerate.utils import CustomDtype
@@ -279,6 +314,16 @@ def create_quantized_param(
 
         quantize_(module, self.quantization_config.get_apply_tensor_subclass())
 
+    def update_state_dict_with_metadata(self, state_dict, metadata):
+        """
+        If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
+        from the provided state_dict and metadata.
+        """
+        if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
+            return unflatten_tensor_state_dict(state_dict, metadata)
+        else:
+            return super().update_state_dict_with_metadata(state_dict, metadata)
+
     def _process_model_after_weight_loading(self, model, **kwargs):
         """No process required for torchao quantized model"""
         if self.quantization_config.quant_type == "autoquant":
@@ -297,10 +342,17 @@ def _process_model_after_weight_loading(self, model, **kwargs):
 
     def is_serializable(self, safe_serialization=None) -> bool:
         if safe_serialization:
-            logger.warning(
-                "torchao quantized model does not support safe serialization, please set `safe_serialization` to False"
-            )
-            return False
+            _is_torchao_serializable = type(
+                self.quantization_config.quant_type
+            ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
+            if not _is_torchao_serializable:
+                logger.warning(
+                    f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
+                    and torchao version >= 0.14.0, please set `safe_serialization` to False for \
+                    {type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
+                )
+            return _is_torchao_serializable
         _is_torchao_serializable = version.parse(importlib.metadata.version("huggingface_hub")) >= version.parse(
             "0.25.0"
         )
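
get_state_dict_and_metadata and update_state_dict_with_metadata are two halves of one round trip built on the torchao prototype helpers. A minimal sketch of that round trip, assuming torchao >= 0.14.0 and that `model` is already quantized with one of the supported float8 configs:

from safetensors import safe_open
from safetensors.torch import save_file
from torchao.prototype.safetensors.safetensors_support import (
    flatten_tensor_state_dict,
    unflatten_tensor_state_dict,
)
from torchao.prototype.safetensors.safetensors_utils import is_metadata_torchao

# Save: split tensor subclasses into plain tensors plus string metadata.
flat_state_dict, metadata = flatten_tensor_state_dict(model.state_dict())
save_file(flat_state_dict, "model.safetensors", metadata=metadata)

# Load: read the plain tensors and the header metadata back.
with safe_open("model.safetensors", framework="pt") as f:
    loaded = {k: f.get_tensor(k) for k in f.keys()}
    loaded_metadata = f.metadata()

# Rebuild the tensor subclasses only if the metadata is torchao's.
if is_metadata_torchao(loaded_metadata):
    state_dict = unflatten_tensor_state_dict(loaded, loaded_metadata)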

tests/quantization/torchao_integration/test_torchao.py

Lines changed: 54 additions & 8 deletions
@@ -18,6 +18,7 @@
 import unittest
 
 from packaging import version
+from parameterized import parameterized
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 from transformers.testing_utils import (
@@ -37,6 +38,8 @@
     import torch
 
 if is_torchao_available():
+    import torchao
+
     # renamed in torchao 0.7.0, please install the latest torchao
     from torchao.dtypes import (
         AffineQuantizedTensor,
@@ -135,7 +138,7 @@ class TorchAoTest(unittest.TestCase):
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     device = "cpu"
     quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
+        {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
@@ -225,6 +228,7 @@ def test_include_input_output_embeddings(self):
             weight_dtype=weight_dtype,
             granularity=granularity,
             mapping_type=mapping_type,
+            version=1,
         )
         config = ModuleFqnToConfig(
             {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
@@ -277,7 +281,7 @@ def test_per_module_config_skip(self):
 @require_torch_accelerator
 class TorchAoAcceleratorTest(TorchAoTest):
     device = torch_device
-    quant_scheme_kwargs = {"group_size": 32}
+    quant_scheme_kwargs = {"group_size": 32, "version": 1}
 
     # called only once for all test in this class
     @classmethod
@@ -327,7 +331,7 @@ def test_int4wo_offload(self):
             "lm_head": 0,
         }
 
-        quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
 
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -399,7 +403,7 @@ def test_autoquant(self):
 
         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
 
-        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )
@@ -414,7 +418,7 @@ class TorchAoSerializationTest(unittest.TestCase):
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
+        {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
@@ -447,13 +451,13 @@ def test_original_model_expected_output(self):
 
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    def check_serialization_expected_output(self, device, expected_output):
+    def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):
         """
         Test if we can serialize and load/infer the model again on the same device
         """
         dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
-            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
+            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device)
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
 
@@ -464,6 +468,48 @@ def test_serialization_expected_output(self):
         self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)
 
 
+@require_torchao
+@require_torchao_version_greater_or_equal("0.14.0")
+class TorchAoSafeSerializationTest(TorchAoSerializationTest):
+    # called only once for all test in this class
+    @classmethod
+    def setUpClass(cls):
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+        gc.collect()
+        if hasattr(self, "quantized_model"):
+            del self.quantized_model
+        gc.collect()
+
+    test_params = (
+        [
+            (
+                torchao.quantization.Float8DynamicActivationFloat8WeightConfig(),
+                "What are we having for dinner?\n\nJess: (smiling) I",
+            ),
+            (torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
+        ]
+        if is_torchao_available()
+        else []
+    )
+
+    @parameterized.expand(test_params, skip_on_empty=True)
+    def test_serialization_expected_output(self, config, expected_output):
+        device = "cuda"
+        self.quant_config = TorchAoConfig(config)
+        self.quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            dtype=torch.bfloat16,
+            device_map=device,
+            quantization_config=self.quant_config,
+        )
+        self.check_serialization_expected_output(device, expected_output, safe_serialization=True)
+
+
 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
 
@@ -500,7 +546,7 @@ def test_serialization_expected_output_on_accelerator(self):
 
 @require_torch_accelerator
 class TorchAoSerializationAcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
+    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32, "version": 1}
     device = f"{torch_device}:0"
 
     # called only once for all test in this class
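
The new suite reuses TorchAoSerializationTest's check helper with safe_serialization=True; to run just the safetensors cases locally (an accelerator plus torchao >= 0.14.0 required), something like:

pytest tests/quantization/torchao_integration/test_torchao.py -k TorchAoSafeSerializationTest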
