@@ -18,6 +18,7 @@
 import unittest
 
 from packaging import version
+from parameterized import parameterized
 
 from transformers import AutoModelForCausalLM, AutoTokenizer, TorchAoConfig
 from transformers.testing_utils import (
@@ -37,6 +38,8 @@
     import torch
 
 if is_torchao_available():
+    import torchao
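+    # module-level torchao import: test_params below references
+    # torchao.quantization config classes directly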
+
     # renamed in torchao 0.7.0, please install the latest torchao
     from torchao.dtypes import (
         AffineQuantizedTensor,
@@ -135,7 +138,7 @@ class TorchAoTest(unittest.TestCase):
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     device = "cpu"
     quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
+        {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
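+        # assumption: "version": 1 pins torchao's legacy config schema
+        # (newer torchao releases default to version 2)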
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
@@ -225,6 +228,7 @@ def test_include_input_output_embeddings(self):
             weight_dtype=weight_dtype,
             granularity=granularity,
             mapping_type=mapping_type,
+            version=1,
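+            # version=1 mirrors the legacy-schema pin applied to the other
+            # configs in this diff (assumption)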
         )
         config = ModuleFqnToConfig(
             {"_default": None, "model.embed_tokens": embedding_config, "lm_head": embedding_config}
@@ -277,7 +281,7 @@ def test_per_module_config_skip(self):
 @require_torch_accelerator
 class TorchAoAcceleratorTest(TorchAoTest):
     device = torch_device
-    quant_scheme_kwargs = {"group_size": 32}
+    quant_scheme_kwargs = {"group_size": 32, "version": 1}
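+    # unlike the CPU test above, no layout kwarg is passed here; only the group
+    # size and the shared version pin (assumption)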
 
     # called only once for all tests in this class
     @classmethod
@@ -327,7 +331,7 @@ def test_int4wo_offload(self):
             "lm_head": 0,
         }
 
-        quant_config = TorchAoConfig("int4_weight_only", group_size=32)
+        quant_config = TorchAoConfig("int4_weight_only", **self.quant_scheme_kwargs)
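+        # reusing the class-level kwargs keeps group_size and the version pin
+        # consistent with the other tests in this class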
 
         quantized_model = AutoModelForCausalLM.from_pretrained(
             self.model_name,
@@ -399,7 +403,7 @@ def test_autoquant(self):
 
         check_autoquantized(self, quantized_model.model.layers[0].self_attn.v_proj)
 
-        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJane: (sighs)"
+        EXPECTED_OUTPUT = "What are we having for dinner?\n\nJessica: (smiling)"
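+        # assumption: the greedy reference string is kernel-dependent, so it
+        # shifts when the autoquant path changes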
         output = quantized_model.generate(
             **input_ids, max_new_tokens=self.max_new_tokens, cache_implementation="static"
         )
@@ -414,7 +418,7 @@ class TorchAoSerializationTest(unittest.TestCase):
     model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
     quant_scheme = "int4_weight_only"
     quant_scheme_kwargs = (
-        {"group_size": 32, "layout": Int4CPULayout()}
+        {"group_size": 32, "layout": Int4CPULayout(), "version": 1}
         if is_torchao_available() and version.parse(importlib.metadata.version("torchao")) >= version.parse("0.8.0")
         else {"group_size": 32}
     )
@@ -447,13 +451,13 @@ def test_original_model_expected_output(self):
 
         self.assertEqual(self.tokenizer.decode(output[0], skip_special_tokens=True), self.EXPECTED_OUTPUT)
 
-    def check_serialization_expected_output(self, device, expected_output):
+    def check_serialization_expected_output(self, device, expected_output, safe_serialization=False):
         """
         Test that we can serialize the model, then load and run inference with it again on the same device.
         safe_serialization=True saves safetensors; the default (False) keeps the pickle-based format.
         """
         dtype = torch.bfloat16 if self.quant_scheme == "int4_weight_only" else "auto"
         with tempfile.TemporaryDirectory() as tmpdirname:
-            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=False)
+            self.quantized_model.save_pretrained(tmpdirname, safe_serialization=safe_serialization)
             loaded_quantized_model = AutoModelForCausalLM.from_pretrained(tmpdirname, dtype=dtype, device_map=device)
             input_ids = self.tokenizer(self.input_text, return_tensors="pt").to(device)
 
@@ -464,6 +468,48 @@ def test_serialization_expected_output(self):
         self.check_serialization_expected_output(self.device, self.EXPECTED_OUTPUT)
 
 
+@require_torchao
+@require_torchao_version_greater_or_equal("0.14.0")
+class TorchAoSafeSerializationTest(TorchAoSerializationTest):
+    # called only once for all tests in this class
+    @classmethod
+    def setUpClass(cls):
+        cls.tokenizer = AutoTokenizer.from_pretrained(cls.model_name)
+        cls.EXPECTED_OUTPUT = "What are we having for dinner?\n- 1. What is the temperature outside"
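+        # EXPECTED_OUTPUT backs the tests inherited from the base class; the
+        # parameterized cases below supply their own expected strings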
+
+    def tearDown(self):
+        gc.collect()
+        backend_empty_cache(torch_device)
+        gc.collect()
+        if hasattr(self, "quantized_model"):
+            del self.quantized_model
+        gc.collect()
+
+    test_params = (
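+        # each case: (torchao config instance, expected greedy generation)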
+        [
+            (
+                torchao.quantization.Float8DynamicActivationFloat8WeightConfig(),
+                "What are we having for dinner?\n\nJess: (smiling) I",
+            ),
+            (torchao.quantization.Float8WeightOnlyConfig(), "What are we having for dinner?\n\nJessica: (smiling)"),
+        ]
+        if is_torchao_available()
+        else []
+    )
+
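+    # skip_on_empty makes parameterized skip case generation when test_params
+    # is empty (torchao not installed) instead of erroring at collection time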
+    @parameterized.expand(test_params, skip_on_empty=True)
+    def test_serialization_expected_output(self, config, expected_output):
+        device = "cuda"
+        self.quant_config = TorchAoConfig(config)
+        self.quantized_model = AutoModelForCausalLM.from_pretrained(
+            self.model_name,
+            dtype=torch.bfloat16,
+            device_map=device,
+            quantization_config=self.quant_config,
+        )
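+        # round-trips through safetensors; the base-class default remains
+        # safe_serialization=False (pickle)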
+        self.check_serialization_expected_output(device, expected_output, safe_serialization=True)
+
+
 class TorchAoSerializationW8A8CPUTest(TorchAoSerializationTest):
     quant_scheme, quant_scheme_kwargs = "int8_dynamic_activation_int8_weight", {}
 
@@ -500,7 +546,7 @@ def test_serialization_expected_output_on_accelerator(self):
 
 @require_torch_accelerator
 class TorchAoSerializationAcceleratorTest(TorchAoSerializationTest):
-    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32}
+    quant_scheme, quant_scheme_kwargs = "int4_weight_only", {"group_size": 32, "version": 1}
     device = f"{torch_device}:0"
 
     # called only once for all tests in this class