 from accelerate.test_utils.testing import run_command
 from accelerate.utils import patch_environment
 from datasets import Audio, DatasetDict, load_dataset
+from parameterized import parameterized
 from transformers import (
     AutoModelForCausalLM,
     AutoModelForSeq2SeqLM,
@@ -950,16 +951,31 @@ def setUp(self):
         self.error_factor = 3
         self.model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"
         self.tokenizer = AutoTokenizer.from_pretrained(self.model_id)
-        self.inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt").to("cuda")
 
-    def get_errors(self, bits=4, loftq_iter=1):
+    def get_input(self, device):
+        inputs = self.tokenizer("All I want is", padding=True, return_tensors="pt")
+        if device == "cuda":
+            inputs = inputs.to("cuda")
+        return inputs
+
+    def get_base_model(self, model_id, device, **kwargs):
+        model = AutoModelForCausalLM.from_pretrained(model_id, **kwargs).eval()
+        if device == "cuda":
+            model = model.to("cuda")
+        return model
+
+    def get_errors(self, bits=4, loftq_iter=1, device="cuda"):
         # Helper function that returns the quantization errors (MAE and MSE) when comparing the quantized LoRA model
         # to the base model, vs the LoftQ quantized model to the base model. We expect the LoftQ quantized model to
         # have less error than the normal LoRA quantized model. Since we compare logits, the observed error is
         # already somewhat dampened because of the softmax.
-        model = AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval()
+        model = self.get_base_model(self.model_id, device)
+        if device == "cuda":
+            model = model.to("cuda")
+
         torch.manual_seed(0)
-        logits_base = model(**self.inputs).logits
+        inputs = self.get_input(device)
+        logits_base = model(**inputs).logits
         # clean up
         del model
         gc.collect()
@@ -976,21 +992,27 @@ def get_errors(self, bits=4, loftq_iter=1):
             raise ValueError("bits must be 4 or 8")
 
         quantized_model = get_peft_model(
-            AutoModelForCausalLM.from_pretrained(self.model_id, device_map="auto", **kwargs).eval(),
+            self.get_base_model(self.model_id, device=None, **kwargs),
             lora_config,
         )
         torch.manual_seed(0)
-        logits_quantized = quantized_model(**self.inputs).logits
+        logits_quantized = quantized_model(**inputs).logits
         del quantized_model
         gc.collect()
         torch.cuda.empty_cache()
 
         # logits from quantized LoRA model using LoftQ
         loftq_config = LoftQConfig(loftq_bits=bits, loftq_iter=loftq_iter)
         lora_config = LoraConfig(task_type=TaskType.CAUSAL_LM, init_lora_weights="loftq", loftq_config=loftq_config)
-        loftq_model = get_peft_model(AutoModelForCausalLM.from_pretrained(self.model_id).cuda().eval(), lora_config)
+        model = self.get_base_model(self.model_id, device)
+        if device == "cuda":
+            model = model.to("cuda")
+        loftq_model = get_peft_model(model, lora_config)
+        if device == "cuda":
+            loftq_model = loftq_model.to("cuda")
+
         torch.manual_seed(0)
-        logits_loftq = loftq_model(**self.inputs).logits
+        logits_loftq = loftq_model(**inputs).logits
         del loftq_model
         gc.collect()
         torch.cuda.empty_cache()
@@ -1001,14 +1023,15 @@ def get_errors(self, bits=4, loftq_iter=1):
         mse_loftq = torch.pow(logits_base - logits_loftq, 2).mean()
         return mae_quantized, mse_quantized, mae_loftq, mse_loftq
 
-    def test_bloomz_loftq_4bit(self):
+    @parameterized.expand(["cuda", "cpu"])
+    def test_bloomz_loftq_4bit(self, device):
         # In this test, we compare the logits of the base model, the quantized LoRA model, and the quantized model
         # using LoftQ. When quantizing, we expect a certain level of error. However, we expect the LoftQ quantized
         # model to have less error than the normal LoRA quantized model. Note that when using normal LoRA, the
         # quantization error is simply the error from quantization without LoRA, as LoRA is a no-op before training.
         # We still apply LoRA for the test for consistency.
 
-        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4)
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, device=device)
         # first, sanity check that all errors are > 0.0
         self.assertTrue(mae_quantized > 0.0)
         self.assertTrue(mse_quantized > 0.0)
@@ -1020,10 +1043,11 @@ def test_bloomz_loftq_4bit(self):
         self.assertTrue(mae_loftq < mae_quantized / factor)
         self.assertTrue(mse_loftq < mse_quantized / factor)
 
-    def test_bloomz_loftq_4bit_iter_5(self):
+    @parameterized.expand(["cuda", "cpu"])
+    def test_bloomz_loftq_4bit_iter_5(self, device):
         # Same test as the previous one but with 5 iterations. We should expect the error to be even smaller with more
         # iterations, but in practice the difference is not that large, at least not for this small base model.
-        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, loftq_iter=5)
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=4, loftq_iter=5, device=device)
         # first, sanity check that all errors are > 0.0
         self.assertTrue(mae_quantized > 0.0)
         self.assertTrue(mse_quantized > 0.0)
@@ -1034,14 +1058,15 @@ def test_bloomz_loftq_4bit_iter_5(self):
         self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
         self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
 
-    def test_bloomz_loftq_8bit(self):
+    @parameterized.expand(["cuda", "cpu"])
+    def test_bloomz_loftq_8bit(self, device):
         # this currently does not work:
         # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
         if True:  # TODO: remove as soon as the issue is fixed
             return
 
         # Same test as test_bloomz_loftq_4bit but with 8 bits.
-        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8)
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, device=device)
 
         # first, sanity check that all errors are > 0.0
         self.assertTrue(mae_quantized > 0.0)
@@ -1053,14 +1078,15 @@ def test_bloomz_loftq_8bit(self):
         self.assertTrue(mae_loftq < mae_quantized / self.error_factor)
         self.assertTrue(mse_loftq < mse_quantized / self.error_factor)
 
-    def test_bloomz_loftq_8bit_iter_5(self):
+    @parameterized.expand(["cuda", "cpu"])
+    def test_bloomz_loftq_8bit_iter_5(self, device):
         # this currently does not work:
         # https://github.com/huggingface/peft/pull/1150#issuecomment-1838891499
         if True:  # TODO: remove as soon as the issue is fixed
             return
 
         # Same test as test_bloomz_loftq_4bit_iter_5 but with 8 bits.
-        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, loftq_iter=5)
+        mae_quantized, mse_quantized, mae_loftq, mse_loftq = self.get_errors(bits=8, loftq_iter=5, device=device)
 
         # first, sanity check that all errors are > 0.0
         self.assertTrue(mae_quantized > 0.0)
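
Note: the diff hinges on `parameterized.expand`, which turns each decorated test into one collected test per device. Below is a minimal, self-contained sketch of that pattern outside the PR, reusing the same tiny model id and the device handling that `get_input`/`get_base_model` introduce; the class name, the CUDA-availability skip guard, and the finiteness check are assumptions added for illustration and are not part of the diff.

import unittest

import torch
from parameterized import parameterized
from transformers import AutoModelForCausalLM, AutoTokenizer


class DeviceParamSketch(unittest.TestCase):
    # Same tiny model as in the tests above.
    model_id = "hf-internal-testing/tiny-random-BloomForCausalLM"

    @parameterized.expand(["cuda", "cpu"])  # collected as test_logits_0_cuda and test_logits_1_cpu
    def test_logits(self, device):
        # Hypothetical guard: skip the CUDA variant on CPU-only machines
        # (the PR's own tests instead run in a GPU test suite).
        if device == "cuda" and not torch.cuda.is_available():
            self.skipTest("CUDA not available")

        tokenizer = AutoTokenizer.from_pretrained(self.model_id)
        model = AutoModelForCausalLM.from_pretrained(self.model_id).eval()
        inputs = tokenizer("All I want is", return_tensors="pt")
        if device == "cuda":
            model = model.to("cuda")
            inputs = inputs.to("cuda")

        with torch.no_grad():
            logits = model(**inputs).logits

        # The real tests compare such logits across the base, quantized LoRA, and
        # LoftQ models via MAE/MSE; here we only check the forward pass succeeds.
        self.assertTrue(torch.isfinite(logits).all())

Running this file with pytest collects both device variants, which is the same mechanism the PR uses to exercise the LoftQ error comparison on both "cuda" and "cpu".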