@@ -98,6 +98,8 @@ def _linear_extra_repr(self):
9898 torchao .quantization .Float8DynamicActivationFloat8WeightConfig ,
9999 ]
100100
101+ TORCHAO_VERSION = version .parse (importlib .metadata .version ("torchao" ))
102+
101103
102104class TorchAoHfQuantizer (HfQuantizer ):
103105 """
@@ -160,12 +162,13 @@ def get_state_dict_and_metadata(self, model, safe_serialization: Optional[bool]
160162 If the model is safe serializable, we flatten the state dict of tensor subclasses so that it is compatible with
161163 the safetensors format.
162164 """
163- if (
164- type (self .quantization_config .quant_type ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS
165- and safe_serialization
166- and version .parse (importlib .metadata .version ("torchao" )) >= version .parse ("0.14.0" )
167- ):
168- return flatten_tensor_state_dict (model .state_dict ())
165+ if type (self .quantization_config .quant_type ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and safe_serialization :
166+ if TORCHAO_VERSION >= version .parse ("0.14.0" ):
167+ return flatten_tensor_state_dict (model .state_dict ())
168+ else :
169+ raise RuntimeError (
170+ f"In order to use safetensors with torchao, please use torchao version >= 0.14.0. Current version: { TORCHAO_VERSION } "
171+ )
169172 else :
170173 return super ().get_state_dict_and_metadata (model )
171174
@@ -316,9 +319,7 @@ def update_state_dict_with_metadata(self, state_dict, metadata):
316319 If the metadata contains torchao tensor subclass information, we reconstruct the tensor subclass state dict
317320 from the provided state_dict and metadata.
318321 """
319- if version .parse (importlib .metadata .version ("torchao" )) >= version .parse ("0.14.0" ) and is_metadata_torchao (
320- metadata
321- ):
322+ if TORCHAO_VERSION >= version .parse ("0.14.0" ) and is_metadata_torchao (metadata ):
322323 return unflatten_tensor_state_dict (state_dict , metadata )
323324 else :
324325 return super ().update_state_dict_with_metadata (state_dict , metadata )
@@ -341,9 +342,9 @@ def _process_model_after_weight_loading(self, model, **kwargs):
341342
342343 def is_serializable (self , safe_serialization = None ) -> bool :
343344 if safe_serialization :
344- _is_torchao_serializable = (
345- type ( self .quantization_config .quant_type ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS
346- )
345+ _is_torchao_serializable = type (
346+ self .quantization_config .quant_type
347+ ) in SUPPORTED_SAFE_SERIALIZATION_CONFIGS and TORCHAO_VERSION >= version . parse ( "0.14.0" )
347348 if not _is_torchao_serializable :
348349 logger .warning (
349350 f"torchao quantized model only supports safe serialization for { SUPPORTED_SAFE_SERIALIZATION_CONFIGS } , \
0 commit comments