Skip to content

Commit e3358c2

Browse files
committed
add more version checking
1 parent d4c4ac6 commit e3358c2

File tree

1 file changed

+15
-13
lines changed

1 file changed

+15
-13
lines changed

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,8 @@ def _linear_extra_repr(self):
9898
torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
9999
]
100100

101+
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
102+
101103

102104
class TorchAoHfQuantizer(HfQuantizer):
103105
"""
@@ -160,12 +162,13 @@ def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool]
160162
If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
161163
the safetensors format.
162164
"""
163-
if (
164-
type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS
165-
and safe_serialization
166-
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0")
167-
):
168-
return flatten_tensor_state_dict(model.state_dict())
165+
if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
166+
if TORCHAO_VERSION >= version.parse("0.14.0"):
167+
return flatten_tensor_state_dict(model.state_dict())
168+
else:
169+
raise RuntimeError(
170+
f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
171+
)
169172
else:
170173
return super().get_state_dict_and_metadata(model)
171174

@@ -316,9 +319,7 @@ def update_state_dict_with_metadata(self, state_dict, metadata):
316319
If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
317320
from the provided state_dict and metadata.
318321
"""
319-
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0") and is_metadata_torchao(
320-
metadata
321-
):
322+
if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
322323
return unflatten_tensor_state_dict(state_dict, metadata)
323324
else:
324325
return super().update_state_dict_with_metadata(state_dict, metadata)
@@ -341,13 +342,14 @@ def _process_model_after_weight_loading(self, model, **kwargs):
341342

342343
def is_serializable(self, safe_serialization=None) -> bool:
343344
if safe_serialization:
344-
_is_torchao_serializable = (
345-
type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS
346-
)
345+
_is_torchao_serializable = type(
346+
self.quantization_config.quant_type
347+
) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
347348
if not _is_torchao_serializable:
348349
logger.warning(
349350
f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \
350-
please set `safe_serialization` to False for {type(self.quantization_config.quant_type)}."
351+
and torchao version >= 0.14.0, please set `safe_serialization` to False for \
352+
{type(self.quantization_config.quant_type)} and {TORCHAO_VERSION}."
351353
)
352354
return _is_torchao_serializable
353355

0 commit comments

Comments
 (0)