Skip to content

Commit 7f82f69

Browse files
committed
add more version checking
1 parent 58641ad commit 7f82f69

File tree

1 file changed

+13
-12
lines changed

1 file changed

+13
-12
lines changed

src/transformers/quantizers/quantizer_torchao.py

Lines changed: 13 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -98,6 +98,8 @@ def _linear_extra_repr(self):
9898
torchao.quantization.Float8DynamicActivationFloat8WeightConfig,
9999
]
100100

101+
TORCHAO_VERSION = version.parse(importlib.metadata.version("torchao"))
102+
101103

102104
class TorchAoHfQuantizer(HfQuantizer):
103105
"""
@@ -160,12 +162,13 @@ def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool]
160162
If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
161163
the safetensors format.
162164
"""
163-
if (
164-
type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS
165-
and safe_serialization
166-
and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0")
167-
):
168-
return flatten_tensor_state_dict(model.state_dict())
165+
if type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization:
166+
if TORCHAO_VERSION >= version.parse("0.14.0"):
167+
return flatten_tensor_state_dict(model.state_dict())
168+
else:
169+
raise RuntimeError(
170+
f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: {TORCHAO_VERSION}"
171+
)
169172
else:
170173
return super().get_state_dict_and_metadata(model)
171174

@@ -316,9 +319,7 @@ def update_state_dict_with_metadata(self, state_dict, metadata):
316319
If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
317320
from the provided state_dict and metadata.
318321
"""
319-
if version.parse(importlib.metadata.version("torchao")) >= version.parse("0.14.0") and is_metadata_torchao(
320-
metadata
321-
):
322+
if TORCHAO_VERSION >= version.parse("0.14.0") and is_metadata_torchao(metadata):
322323
return unflatten_tensor_state_dict(state_dict, metadata)
323324
else:
324325
return super().update_state_dict_with_metadata(state_dict, metadata)
@@ -341,9 +342,9 @@ def _process_model_after_weight_loading(self, model, **kwargs):
341342

342343
def is_serializable(self, safe_serialization=None) -> bool:
343344
if safe_serialization:
344-
_is_torchao_serializable = (
345-
type(self.quantization_config.quant_type) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS
346-
)
345+
_is_torchao_serializable = type(
346+
self.quantization_config.quant_type
347+
) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version.parse("0.14.0")
347348
if not _is_torchao_serializable:
348349
logger.warning(
349350
f"torchao quantized model only supports safe serialization for {SUPPORTED_SAFE_SERIALIZATION_CONFIGS}, \

0 commit comments

Comments
 (0)